diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml index a35fec394d23..2acbe66f4caf 100644 --- a/.buildkite/postgres-config.yaml +++ b/.buildkite/postgres-config.yaml @@ -1,7 +1,7 @@ # Configuration file used for testing the 'synapse_port_db' script. # Tells the script to connect to the postgresql database that will be available in the # CI's Docker setup at the point where this file is considered. -server_name: "test" +server_name: "localhost:8800" signing_key_path: "/src/.buildkite/test.signing.key" diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml index 635b92176469..6d9bf80d844b 100644 --- a/.buildkite/sqlite-config.yaml +++ b/.buildkite/sqlite-config.yaml @@ -1,7 +1,7 @@ # Configuration file used for testing the 'synapse_port_db' script. # Tells the 'update_database' script to connect to the test SQLite database to upgrade its # schema and run background updates on it. -server_name: "test" +server_name: "localhost:8800" signing_key_path: "/src/.buildkite/test.signing.key" diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8939fda67d5d..11fb05ca966b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,8 +1,8 @@ ### Pull Request Checklist - + * [ ] Pull request is based on the develop branch -* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog) -* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off) -* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style)) +* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#changelog) +* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#sign-off) +* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#code-style)) diff --git a/CHANGES.md b/CHANGES.md index a9afd36d2c92..0ef9794aac26 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,81 @@ +Synapse 1.7.0rc1 (2019-12-09) +============================= + +Features +-------- + +- Implement per-room message retention policies. ([\#5815](https://github.com/matrix-org/synapse/issues/5815)) +- Add etag and count fields to key backup endpoints to help clients guess if there are new keys. ([\#5858](https://github.com/matrix-org/synapse/issues/5858)) +- Add admin/v2/users endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925)) +- Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. ([\#6119](https://github.com/matrix-org/synapse/issues/6119)) +- Implement the `/_matrix/federation/unstable/net.atleastfornow/state/` API as drafted in MSC2314. ([\#6176](https://github.com/matrix-org/synapse/issues/6176)) +- Configure privacy preserving settings by default for the room directory. ([\#6354](https://github.com/matrix-org/synapse/issues/6354)) +- Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). ([\#6409](https://github.com/matrix-org/synapse/issues/6409)) +- Add support for MSC 2367, which allows specifying a reason on all membership events. ([\#6434](https://github.com/matrix-org/synapse/issues/6434)) + + +Bugfixes +-------- + +- Transfer non-standard power levels on room upgrade. ([\#6237](https://github.com/matrix-org/synapse/issues/6237)) +- Fix error from the Pillow library when uploading RGBA images. ([\#6241](https://github.com/matrix-org/synapse/issues/6241)) +- Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests. ([\#6329](https://github.com/matrix-org/synapse/issues/6329)) +- Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. ([\#6332](https://github.com/matrix-org/synapse/issues/6332)) +- Prevent account data syncs getting lost across TCP replication. ([\#6333](https://github.com/matrix-org/synapse/issues/6333)) +- Fix bug: TypeError in `register_user()` while using LDAP auth module. ([\#6406](https://github.com/matrix-org/synapse/issues/6406)) +- Fix an intermittent exception when handling read-receipts. ([\#6408](https://github.com/matrix-org/synapse/issues/6408)) +- Fix broken guest registration when there are existing blocks of numeric user IDs. ([\#6420](https://github.com/matrix-org/synapse/issues/6420)) +- Fix startup error when http proxy is defined. ([\#6421](https://github.com/matrix-org/synapse/issues/6421)) +- Clean up local threepids from user on account deactivation. ([\#6426](https://github.com/matrix-org/synapse/issues/6426)) +- Fix a bug where a room could become unusable with a low retention policy and a low activity. ([\#6436](https://github.com/matrix-org/synapse/issues/6436)) +- Fix error when using synapse_port_db on a vanilla synapse db. ([\#6449](https://github.com/matrix-org/synapse/issues/6449)) +- Fix uploading multiple cross signing signatures for the same user. ([\#6451](https://github.com/matrix-org/synapse/issues/6451)) +- Fix bug which lead to exceptions being thrown in a loop when a cross-signed device is deleted. ([\#6462](https://github.com/matrix-org/synapse/issues/6462)) +- Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process. ([\#6470](https://github.com/matrix-org/synapse/issues/6470)) +- Improve sanity-checking when receiving events over federation. ([\#6472](https://github.com/matrix-org/synapse/issues/6472)) +- Fix inaccurate per-block Prometheus metrics. ([\#6491](https://github.com/matrix-org/synapse/issues/6491)) +- Fix small performance regression for sending invites. ([\#6493](https://github.com/matrix-org/synapse/issues/6493)) +- Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. ([\#6494](https://github.com/matrix-org/synapse/issues/6494)) + + +Improved Documentation +---------------------- + +- Update documentation and variables in user contributed systemd reference file. ([\#6369](https://github.com/matrix-org/synapse/issues/6369), [\#6490](https://github.com/matrix-org/synapse/issues/6490)) +- Fix link in the user directory documentation. ([\#6388](https://github.com/matrix-org/synapse/issues/6388)) +- Add build instructions to the docker readme. ([\#6390](https://github.com/matrix-org/synapse/issues/6390)) +- Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md. ([\#6443](https://github.com/matrix-org/synapse/issues/6443)) +- Write some docs for the quarantine_media api. ([\#6458](https://github.com/matrix-org/synapse/issues/6458)) +- Convert CONTRIBUTING.rst to markdown (among other small fixes). ([\#6461](https://github.com/matrix-org/synapse/issues/6461)) + + +Deprecations and Removals +------------------------- + +- Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925)) +- Remove fallback for federation with old servers which lack the /federation/v1/state_ids API. ([\#6488](https://github.com/matrix-org/synapse/issues/6488)) + + +Internal Changes +---------------- + +- Add benchmarks for structured logging and improve output performance. ([\#6266](https://github.com/matrix-org/synapse/issues/6266)) +- Improve the performance of outputting structured logging. ([\#6322](https://github.com/matrix-org/synapse/issues/6322)) +- Refactor some code in the event authentication path for clarity. ([\#6343](https://github.com/matrix-org/synapse/issues/6343), [\#6468](https://github.com/matrix-org/synapse/issues/6468), [\#6480](https://github.com/matrix-org/synapse/issues/6480)) +- Clean up some unnecessary quotation marks around the codebase. ([\#6362](https://github.com/matrix-org/synapse/issues/6362)) +- Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. ([\#6379](https://github.com/matrix-org/synapse/issues/6379)) +- Add a test scenario to make sure room history purges don't break `/messages` in the future. ([\#6392](https://github.com/matrix-org/synapse/issues/6392)) +- Clarifications for the email configuration settings. ([\#6423](https://github.com/matrix-org/synapse/issues/6423)) +- Add more tests to the blacklist when running in worker mode. ([\#6429](https://github.com/matrix-org/synapse/issues/6429)) +- Move data store specific code out of `SQLBaseStore`. ([\#6454](https://github.com/matrix-org/synapse/issues/6454)) +- Prepare SQLBaseStore functions being moved out of the stores. ([\#6464](https://github.com/matrix-org/synapse/issues/6464)) +- Move per database functionality out of the data stores and into a dedicated `Database` class. ([\#6469](https://github.com/matrix-org/synapse/issues/6469)) +- Port synapse.rest.client.v1 to async/await. ([\#6482](https://github.com/matrix-org/synapse/issues/6482)) +- Port synapse.rest.client.v2_alpha to async/await. ([\#6483](https://github.com/matrix-org/synapse/issues/6483)) +- Port SyncHandler to async/await. ([\#6484](https://github.com/matrix-org/synapse/issues/6484)) +- Pass in `Database` object to data stores. ([\#6487](https://github.com/matrix-org/synapse/issues/6487)) + + Synapse 1.6.1 (2019-11-28) ========================== diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000000..c0091346f3b6 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,210 @@ +# Contributing code to Matrix + +Everyone is welcome to contribute code to Matrix +(https://github.com/matrix-org), provided that they are willing to license +their contributions under the same license as the project itself. We follow a +simple 'inbound=outbound' model for contributions: the act of submitting an +'inbound' contribution means that the contributor agrees to license the code +under the same terms as the project's overall 'outbound' license - in our +case, this is almost always Apache Software License v2 (see [LICENSE](LICENSE)). + +## How to contribute + +The preferred and easiest way to contribute changes to Matrix is to fork the +relevant project on github, and then [create a pull request]( +https://help.github.com/articles/using-pull-requests/) to ask us to pull +your changes into our repo. + +**The single biggest thing you need to know is: please base your changes on +the develop branch - *not* master.** + +We use the master branch to track the most recent release, so that folks who +blindly clone the repo and automatically check out master get something that +works. Develop is the unstable branch where all the development actually +happens: the workflow is that contributors should fork the develop branch to +make a 'feature' branch for a particular contribution, and then make a pull +request to merge this back into the matrix.org 'official' develop branch. We +use github's pull request workflow to review the contribution, and either ask +you to make any refinements needed or merge it and make them ourselves. The +changes will then land on master when we next do a release. + +We use [Buildkite](https://buildkite.com/matrix-dot-org/synapse) for continuous +integration. If your change breaks the build, this will be shown in GitHub, so +please keep an eye on the pull request for feedback. + +To run unit tests in a local development environment, you can use: + +- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) + for SQLite-backed Synapse on Python 3.5. +- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. +- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 + (requires a running local PostgreSQL with access to create databases). +- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 + (requires Docker). Entirely self-contained, recommended if you don't want to + set up PostgreSQL yourself. + +Docker images are available for running the integration tests (SyTest) locally, +see the [documentation in the SyTest repo]( +https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more +information. + +## Code style + +All Matrix projects have a well-defined code-style - and sometimes we've even +got as far as documenting it... For instance, synapse's code style doc lives +[here](docs/code_style.md). + +To facilitate meeting these criteria you can run `scripts-dev/lint.sh` +locally. Since this runs the tools listed in the above document, you'll need +python 3.6 and to install each tool: + +``` +# Install the dependencies +pip install -U black flake8 isort + +# Run the linter script +./scripts-dev/lint.sh +``` + +**Note that the script does not just test/check, but also reformats code, so you +may wish to ensure any new code is committed first**. By default this script +checks all files and can take some time; if you alter only certain files, you +might wish to specify paths as arguments to reduce the run-time: + +``` +./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder +``` + +Before pushing new changes, ensure they don't produce linting errors. Commit any +files that were corrected. + +Please ensure your changes match the cosmetic style of the existing project, +and **never** mix cosmetic and functional changes in the same commit, as it +makes it horribly hard to review otherwise. + + +## Changelog + +All changes, even minor ones, need a corresponding changelog / newsfragment +entry. These are managed by [Towncrier](https://github.com/hawkowl/towncrier). + +To create a changelog entry, make a new file in the `changelog.d` directory named +in the format of `PRnumber.type`. The type can be one of the following: + +* `feature` +* `bugfix` +* `docker` (for updates to the Docker image) +* `doc` (for updates to the documentation) +* `removal` (also used for deprecations) +* `misc` (for internal-only changes) + +The content of the file is your changelog entry, which should be a short +description of your change in the same style as the rest of our [changelog]( +https://github.com/matrix-org/synapse/blob/master/CHANGES.md). The file can +contain Markdown formatting, and should end with a full stop ('.') for +consistency. + +Adding credits to the changelog is encouraged, we value your +contributions and would like to have you shouted out in the release notes! + +For example, a fix in PR #1234 would have its changelog entry in +`changelog.d/1234.bugfix`, and contain content like "The security levels of +Florbs are now validated when received over federation. Contributed by Jane +Matrix.". + +## Debian changelog + +Changes which affect the debian packaging files (in `debian`) are an +exception. + +In this case, you will need to add an entry to the debian changelog for the +next release. For this, run the following command: + +``` +dch +``` + +This will make up a new version number (if there isn't already an unreleased +version in flight), and open an editor where you can add a new changelog entry. +(Our release process will ensure that the version number and maintainer name is +corrected for the release.) + +If your change affects both the debian packaging *and* files outside the debian +directory, you will need both a regular newsfragment *and* an entry in the +debian changelog. (Though typically such changes should be submitted as two +separate pull requests.) + +## Sign off + +In order to have a concrete record that your contribution is intentional +and you agree to license it under the same terms as the project's license, we've adopted the +same lightweight approach that the Linux Kernel +[submitting patches process]( +https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>), +[Docker](https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other +projects use: the DCO (Developer Certificate of Origin: +http://developercertificate.org/). This is a simple declaration that you wrote +the contribution or otherwise have the right to contribute it to Matrix: + +``` +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +660 York Street, Suite 102, +San Francisco, CA 94110 USA + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. +``` + +If you agree to this for your contribution, then all that's needed is to +include the line in your commit or pull request comment: + +``` +Signed-off-by: Your Name +``` + +We accept contributions under a legally identifiable name, such as +your name on government documentation or common-law names (names +claimed by legitimate usage or repute). Unfortunately, we cannot +accept anonymous contributions at this time. + +Git allows you to add this signoff automatically when using the `-s` +flag to `git commit`, which uses the name and email set in your +`user.name` and `user.email` git configs. + +## Conclusion + +That's it! Matrix is a very open and collaborative project as you might expect +given our obsession with open communication. If we're going to successfully +matrix together all the fragmented communication technologies out there we are +reliant on contributions and collaboration from the community to do so. So +please get involved - and we hope you have as much fun hacking on Matrix as we +do! diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst deleted file mode 100644 index df81f6e54f9d..000000000000 --- a/CONTRIBUTING.rst +++ /dev/null @@ -1,206 +0,0 @@ -Contributing code to Matrix -=========================== - -Everyone is welcome to contribute code to Matrix -(https://github.com/matrix-org), provided that they are willing to license -their contributions under the same license as the project itself. We follow a -simple 'inbound=outbound' model for contributions: the act of submitting an -'inbound' contribution means that the contributor agrees to license the code -under the same terms as the project's overall 'outbound' license - in our -case, this is almost always Apache Software License v2 (see LICENSE). - -How to contribute -~~~~~~~~~~~~~~~~~ - -The preferred and easiest way to contribute changes to Matrix is to fork the -relevant project on github, and then create a pull request to ask us to pull -your changes into our repo -(https://help.github.com/articles/using-pull-requests/) - -**The single biggest thing you need to know is: please base your changes on -the develop branch - /not/ master.** - -We use the master branch to track the most recent release, so that folks who -blindly clone the repo and automatically check out master get something that -works. Develop is the unstable branch where all the development actually -happens: the workflow is that contributors should fork the develop branch to -make a 'feature' branch for a particular contribution, and then make a pull -request to merge this back into the matrix.org 'official' develop branch. We -use github's pull request workflow to review the contribution, and either ask -you to make any refinements needed or merge it and make them ourselves. The -changes will then land on master when we next do a release. - -We use `Buildkite `_ for -continuous integration. Buildkite builds need to be authorised by a -maintainer. If your change breaks the build, this will be shown in GitHub, so -please keep an eye on the pull request for feedback. - -To run unit tests in a local development environment, you can use: - -- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) - for SQLite-backed Synapse on Python 3.5. -- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. -- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 - (requires a running local PostgreSQL with access to create databases). -- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 - (requires Docker). Entirely self-contained, recommended if you don't want to - set up PostgreSQL yourself. - -Docker images are available for running the integration tests (SyTest) locally, -see the `documentation in the SyTest repo -`_ for more -information. - -Code style -~~~~~~~~~~ - -All Matrix projects have a well-defined code-style - and sometimes we've even -got as far as documenting it... For instance, synapse's code style doc lives -at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md. - -To facilitate meeting these criteria you can run ``scripts-dev/lint.sh`` -locally. Since this runs the tools listed in the above document, you'll need -python 3.6 and to install each tool. **Note that the script does not just -test/check, but also reformats code, so you may wish to ensure any new code is -committed first**. By default this script checks all files and can take some -time; if you alter only certain files, you might wish to specify paths as -arguments to reduce the run-time. - -Please ensure your changes match the cosmetic style of the existing project, -and **never** mix cosmetic and functional changes in the same commit, as it -makes it horribly hard to review otherwise. - -Before doing a commit, ensure the changes you've made don't produce -linting errors. You can do this by running the linters as follows. Ensure to -commit any files that were corrected. - -:: - # Install the dependencies - pip install -U black flake8 isort - - # Run the linter script - ./scripts-dev/lint.sh - -Changelog -~~~~~~~~~ - -All changes, even minor ones, need a corresponding changelog / newsfragment -entry. These are managed by Towncrier -(https://github.com/hawkowl/towncrier). - -To create a changelog entry, make a new file in the ``changelog.d`` file named -in the format of ``PRnumber.type``. The type can be one of the following: - -* ``feature``. -* ``bugfix``. -* ``docker`` (for updates to the Docker image). -* ``doc`` (for updates to the documentation). -* ``removal`` (also used for deprecations). -* ``misc`` (for internal-only changes). - -The content of the file is your changelog entry, which should be a short -description of your change in the same style as the rest of our `changelog -`_. The file can -contain Markdown formatting, and should end with a full stop ('.') for -consistency. - -Adding credits to the changelog is encouraged, we value your -contributions and would like to have you shouted out in the release notes! - -For example, a fix in PR #1234 would have its changelog entry in -``changelog.d/1234.bugfix``, and contain content like "The security levels of -Florbs are now validated when recieved over federation. Contributed by Jane -Matrix.". - -Debian changelog ----------------- - -Changes which affect the debian packaging files (in ``debian``) are an -exception. - -In this case, you will need to add an entry to the debian changelog for the -next release. For this, run the following command:: - - dch - -This will make up a new version number (if there isn't already an unreleased -version in flight), and open an editor where you can add a new changelog entry. -(Our release process will ensure that the version number and maintainer name is -corrected for the release.) - -If your change affects both the debian packaging *and* files outside the debian -directory, you will need both a regular newsfragment *and* an entry in the -debian changelog. (Though typically such changes should be submitted as two -separate pull requests.) - -Sign off -~~~~~~~~ - -In order to have a concrete record that your contribution is intentional -and you agree to license it under the same terms as the project's license, we've adopted the -same lightweight approach that the Linux Kernel -`submitting patches process `_, Docker -(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other -projects use: the DCO (Developer Certificate of Origin: -http://developercertificate.org/). This is a simple declaration that you wrote -the contribution or otherwise have the right to contribute it to Matrix:: - - Developer Certificate of Origin - Version 1.1 - - Copyright (C) 2004, 2006 The Linux Foundation and its contributors. - 660 York Street, Suite 102, - San Francisco, CA 94110 USA - - Everyone is permitted to copy and distribute verbatim copies of this - license document, but changing it is not allowed. - - Developer's Certificate of Origin 1.1 - - By making a contribution to this project, I certify that: - - (a) The contribution was created in whole or in part by me and I - have the right to submit it under the open source license - indicated in the file; or - - (b) The contribution is based upon previous work that, to the best - of my knowledge, is covered under an appropriate open source - license and I have the right under that license to submit that - work with modifications, whether created in whole or in part - by me, under the same open source license (unless I am - permitted to submit under a different license), as indicated - in the file; or - - (c) The contribution was provided directly to me by some other - person who certified (a), (b) or (c) and I have not modified - it. - - (d) I understand and agree that this project and the contribution - are public and that a record of the contribution (including all - personal information I submit with it, including my sign-off) is - maintained indefinitely and may be redistributed consistent with - this project or the open source license(s) involved. - -If you agree to this for your contribution, then all that's needed is to -include the line in your commit or pull request comment:: - - Signed-off-by: Your Name - -We accept contributions under a legally identifiable name, such as -your name on government documentation or common-law names (names -claimed by legitimate usage or repute). Unfortunately, we cannot -accept anonymous contributions at this time. - -Git allows you to add this signoff automatically when using the ``-s`` -flag to ``git commit``, which uses the name and email set in your -``user.name`` and ``user.email`` git configs. - -Conclusion -~~~~~~~~~~ - -That's it! Matrix is a very open and collaborative project as you might expect -given our obsession with open communication. If we're going to successfully -matrix together all the fragmented communication technologies out there we are -reliant on contributions and collaboration from the community to do so. So -please get involved - and we hope you have as much fun hacking on Matrix as we -do! diff --git a/INSTALL.md b/INSTALL.md index 9b7360f0efba..9da2e3c734e2 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -109,8 +109,8 @@ Installing prerequisites on Ubuntu or Debian: ``` sudo apt-get install build-essential python3-dev libffi-dev \ - python-pip python-setuptools sqlite3 \ - libssl-dev python-virtualenv libjpeg-dev libxslt1-dev + python3-pip python3-setuptools sqlite3 \ + libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev ``` #### ArchLinux diff --git a/UPGRADE.rst b/UPGRADE.rst index 5ebf16a73e77..d9020f2663ae 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,23 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.7.0 +=================== + +In an attempt to configure Synapse in a privacy preserving way, the default +behaviours of ``allow_public_rooms_without_auth`` and +``allow_public_rooms_over_federation`` have been inverted. This means that by +default, only authenticated users querying the Client/Server API will be able +to query the room directory, and relatedly that the server will not share +room directory information with other servers over federation. + +If your installation does not explicitly set these settings one way or the other +and you want either setting to be ``true`` then it will necessary to update +your homeserver configuration file accordingly. + +For more details on the surrounding context see our `explainer +`_. + Upgrading to v1.5.0 =================== diff --git a/changelog.d/5815.feature b/changelog.d/5815.feature deleted file mode 100644 index ca4df4e7f66d..000000000000 --- a/changelog.d/5815.feature +++ /dev/null @@ -1 +0,0 @@ -Implement per-room message retention policies. diff --git a/changelog.d/5858.feature b/changelog.d/5858.feature deleted file mode 100644 index 55ee93051e37..000000000000 --- a/changelog.d/5858.feature +++ /dev/null @@ -1 +0,0 @@ -Add etag and count fields to key backup endpoints to help clients guess if there are new keys. diff --git a/changelog.d/6119.feature b/changelog.d/6119.feature deleted file mode 100644 index 1492e83c5a9b..000000000000 --- a/changelog.d/6119.feature +++ /dev/null @@ -1 +0,0 @@ -Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. \ No newline at end of file diff --git a/changelog.d/6176.feature b/changelog.d/6176.feature deleted file mode 100644 index 3c66d689d4a4..000000000000 --- a/changelog.d/6176.feature +++ /dev/null @@ -1 +0,0 @@ -Implement the `/_matrix/federation/unstable/net.atleastfornow/state/` API as drafted in MSC2314. diff --git a/changelog.d/6322.misc b/changelog.d/6322.misc deleted file mode 100644 index 70ef36ca806b..000000000000 --- a/changelog.d/6322.misc +++ /dev/null @@ -1 +0,0 @@ -Improve the performance of outputting structured logging. diff --git a/changelog.d/6332.bugfix b/changelog.d/6332.bugfix deleted file mode 100644 index 67d5170ba0de..000000000000 --- a/changelog.d/6332.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. diff --git a/changelog.d/6333.bugfix b/changelog.d/6333.bugfix deleted file mode 100644 index a25d6ef3cb82..000000000000 --- a/changelog.d/6333.bugfix +++ /dev/null @@ -1 +0,0 @@ -Prevent account data syncs getting lost across TCP replication. \ No newline at end of file diff --git a/changelog.d/6343.misc b/changelog.d/6343.misc deleted file mode 100644 index d9a44389b976..000000000000 --- a/changelog.d/6343.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor some code in the event authentication path for clarity. diff --git a/changelog.d/6362.misc b/changelog.d/6362.misc deleted file mode 100644 index b79a5bea9920..000000000000 --- a/changelog.d/6362.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some unnecessary quotation marks around the codebase. \ No newline at end of file diff --git a/changelog.d/6379.misc b/changelog.d/6379.misc deleted file mode 100644 index 725c2e7d8744..000000000000 --- a/changelog.d/6379.misc +++ /dev/null @@ -1 +0,0 @@ -Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. \ No newline at end of file diff --git a/changelog.d/6388.doc b/changelog.d/6388.doc deleted file mode 100644 index c777cb6b8f12..000000000000 --- a/changelog.d/6388.doc +++ /dev/null @@ -1 +0,0 @@ -Fix link in the user directory documentation. diff --git a/changelog.d/6390.doc b/changelog.d/6390.doc deleted file mode 100644 index 093411bec140..000000000000 --- a/changelog.d/6390.doc +++ /dev/null @@ -1 +0,0 @@ -Add build instructions to the docker readme. \ No newline at end of file diff --git a/changelog.d/6392.misc b/changelog.d/6392.misc deleted file mode 100644 index a00257944f2c..000000000000 --- a/changelog.d/6392.misc +++ /dev/null @@ -1 +0,0 @@ -Add a test scenario to make sure room history purges don't break `/messages` in the future. diff --git a/changelog.d/6408.bugfix b/changelog.d/6408.bugfix deleted file mode 100644 index c9babe599b7e..000000000000 --- a/changelog.d/6408.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix an intermittent exception when handling read-receipts. diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature new file mode 100644 index 000000000000..ebea4a208d31 --- /dev/null +++ b/changelog.d/6411.feature @@ -0,0 +1 @@ +Allow custom SAML username mapping functinality through an external provider plugin. \ No newline at end of file diff --git a/changelog.d/6420.bugfix b/changelog.d/6420.bugfix deleted file mode 100644 index aef47cccaaa8..000000000000 --- a/changelog.d/6420.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix broken guest registration when there are existing blocks of numeric user IDs. diff --git a/changelog.d/6421.bugfix b/changelog.d/6421.bugfix deleted file mode 100644 index 7969f7f71dc3..000000000000 --- a/changelog.d/6421.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix startup error when http proxy is defined. diff --git a/changelog.d/6423.misc b/changelog.d/6423.misc deleted file mode 100644 index 9bcd5d36c132..000000000000 --- a/changelog.d/6423.misc +++ /dev/null @@ -1 +0,0 @@ -Clarifications for the email configuration settings. diff --git a/changelog.d/6426.bugfix b/changelog.d/6426.bugfix deleted file mode 100644 index 3acfde421116..000000000000 --- a/changelog.d/6426.bugfix +++ /dev/null @@ -1 +0,0 @@ -Clean up local threepids from user on account deactivation. \ No newline at end of file diff --git a/changelog.d/6429.misc b/changelog.d/6429.misc deleted file mode 100644 index 4b32cdeac6aa..000000000000 --- a/changelog.d/6429.misc +++ /dev/null @@ -1 +0,0 @@ -Add more tests to the blacklist when running in worker mode. diff --git a/changelog.d/6434.feature b/changelog.d/6434.feature deleted file mode 100644 index affa5d50c1cd..000000000000 --- a/changelog.d/6434.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for MSC 2367, which allows specifying a reason on all membership events. diff --git a/changelog.d/6436.bugfix b/changelog.d/6436.bugfix deleted file mode 100644 index 954a4e1d84f9..000000000000 --- a/changelog.d/6436.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where a room could become unusable with a low retention policy and a low activity. diff --git a/changelog.d/6497.bugfix b/changelog.d/6497.bugfix new file mode 100644 index 000000000000..92ed08fc4077 --- /dev/null +++ b/changelog.d/6497.bugfix @@ -0,0 +1 @@ +Fix error message when setting your profile's avatar URL mentioning displaynames, and prevent NoneType avatar_urls. \ No newline at end of file diff --git a/changelog.d/6499.bugfix b/changelog.d/6499.bugfix new file mode 100644 index 000000000000..299feba0f82f --- /dev/null +++ b/changelog.d/6499.bugfix @@ -0,0 +1 @@ +Fix support for SQLite 3.7. diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc new file mode 100644 index 000000000000..255f45a9c3d5 --- /dev/null +++ b/changelog.d/6501.misc @@ -0,0 +1 @@ +Refactor get_events_from_store_or_dest to return a dict. diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal new file mode 100644 index 000000000000..0b72261d58b1 --- /dev/null +++ b/changelog.d/6502.removal @@ -0,0 +1 @@ +Remove redundant code from event authorisation implementation. diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc new file mode 100644 index 000000000000..e4e9a5a3d4e5 --- /dev/null +++ b/changelog.d/6503.misc @@ -0,0 +1 @@ +Move get_state methods into FederationHandler. diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc new file mode 100644 index 000000000000..3a75b2d9dd23 --- /dev/null +++ b/changelog.d/6505.misc @@ -0,0 +1 @@ +Make `make_deferred_yieldable` to work with async/await. diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc new file mode 100644 index 000000000000..99d7a70bcf0a --- /dev/null +++ b/changelog.d/6506.misc @@ -0,0 +1 @@ +Remove `SnapshotCache` in favour of `ResponseCache`. diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc new file mode 100644 index 000000000000..214f06539b02 --- /dev/null +++ b/changelog.d/6510.misc @@ -0,0 +1 @@ +Change phone home stats to not assume there is a single database and report information about the database used by the main data store. diff --git a/changelog.d/6512.misc b/changelog.d/6512.misc new file mode 100644 index 000000000000..37a8099eec38 --- /dev/null +++ b/changelog.d/6512.misc @@ -0,0 +1 @@ +Silence mypy errors for files outside those specified. diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix new file mode 100644 index 000000000000..6dc1985c2404 --- /dev/null +++ b/changelog.d/6514.bugfix @@ -0,0 +1 @@ +Fix race which occasionally caused deleted devices to reappear. diff --git a/contrib/systemd/README.md b/contrib/systemd/README.md new file mode 100644 index 000000000000..5d42b3464f68 --- /dev/null +++ b/contrib/systemd/README.md @@ -0,0 +1,17 @@ +# Setup Synapse with Systemd +This is a setup for managing synapse with a user contributed systemd unit +file. It provides a `matrix-synapse` systemd unit file that should be tailored +to accommodate your installation in accordance with the installation +instructions provided in [installation instructions](../../INSTALL.md). + +## Setup +1. Under the service section, ensure the `User` variable matches which user +you installed synapse under and wish to run it as. +2. Under the service section, ensure the `WorkingDirectory` variable matches +where you have installed synapse. +3. Under the service section, ensure the `ExecStart` variable matches the +appropriate locations of your installation. +4. Copy the `matrix-synapse.service` to `/etc/systemd/system/` +5. Start Synapse: `sudo systemctl start matrix-synapse` +6. Verify Synapse is running: `sudo systemctl status matrix-synapse` +7. *optional* Enable Synapse to start at system boot: `sudo systemctl enable matrix-synapse` diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service index 38d369ea3d49..813717b032d7 100644 --- a/contrib/systemd/matrix-synapse.service +++ b/contrib/systemd/matrix-synapse.service @@ -4,8 +4,11 @@ # systemctl enable matrix-synapse # systemctl start matrix-synapse # +# This assumes that Synapse has been installed by a user named +# synapse. +# # This assumes that Synapse has been installed in a virtualenv in -# /opt/synapse/env. +# the user's home directory: `/home/synapse/synapse/env`. # # **NOTE:** This is an example service file that may change in the future. If you # wish to use this please copy rather than symlink it. @@ -22,8 +25,8 @@ Restart=on-abort User=synapse Group=nogroup -WorkingDirectory=/opt/synapse -ExecStart=/opt/synapse/env/bin/python -m synapse.app.homeserver --config-path=/opt/synapse/homeserver.yaml +WorkingDirectory=/home/synapse/synapse +ExecStart=/home/synapse/synapse/env/bin/python -m synapse.app.homeserver --config-path=/home/synapse/synapse/homeserver.yaml SyslogIdentifier=matrix-synapse # adjust the cache factor if necessary diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 5e9f8e5d843f..8b3666d5f588 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -21,3 +21,20 @@ It returns a JSON body like the following: ] } ``` + +# Quarantine media in a room + +This API 'quarantines' all the media in a room. + +The API is: + +``` +POST /_synapse/admin/v1/quarantine_media/ + +{} +``` + +Quarantining media means that it is marked as inaccessible by users. It applies +to any local media, and any locally-cached copies of remote media. + +The media file itself (and any thumbnails) is not deleted from the server. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index d0871f943844..b451dc501401 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -1,3 +1,48 @@ +List Accounts +============= + +This API returns all local user accounts. + +The api is:: + + GET /_synapse/admin/v2/users?from=0&limit=10&guests=false + +including an ``access_token`` of a server admin. +The parameters ``from`` and ``limit`` are required only for pagination. +By default, a ``limit`` of 100 is used. +The parameter ``user_id`` can be used to select only users with user ids that +contain this value. +The parameter ``guests=false`` can be used to exclude guest users, +default is to include guest users. +The parameter ``deactivated=true`` can be used to include deactivated users, +default is to exclude deactivated users. +If the endpoint does not return a ``next_token`` then there are no more users left. +It returns a JSON body like the following: + +.. code:: json + + { + "users": [ + { + "name": "", + "password_hash": "", + "is_guest": 0, + "admin": 0, + "user_type": null, + "deactivated": 0 + }, { + "name": "", + "password_hash": "", + "is_guest": 0, + "admin": 1, + "user_type": null, + "deactivated": 0 + } + ], + "next_token": "100" + } + + Query Account ============= diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md new file mode 100644 index 000000000000..92f2380488c3 --- /dev/null +++ b/docs/saml_mapping_providers.md @@ -0,0 +1,77 @@ +# SAML Mapping Providers + +A SAML mapping provider is a Python class (loaded via a Python module) that +works out how to map attributes of a SAML response object to Matrix-specific +user attributes. Details such as user ID localpart, displayname, and even avatar +URLs are all things that can be mapped from talking to a SSO service. + +As an example, a SSO service may return the email address +"john.smith@example.com" for a user, whereas Synapse will need to figure out how +to turn that into a displayname when creating a Matrix user for this individual. +It may choose `John Smith`, or `Smith, John [Example.com]` or any number of +variations. As each Synapse configuration may want something different, this is +where SAML mapping providers come into play. + +## Enabling Providers + +External mapping providers are provided to Synapse in the form of an external +Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere, +then tell Synapse where to look for the handler class by editing the +`saml2_config.user_mapping_provider.module` config option. + +`saml2_config.user_mapping_provider.config` allows you to provide custom +configuration options to the module. Check with the module's documentation for +what options it provides (if any). The options listed by default are for the +user mapping provider built in to Synapse. If using a custom module, you should +comment these options out and use those specified by the module instead. + +## Building a Custom Mapping Provider + +A custom mapping provider must specify the following methods: + +* `__init__(self, parsed_config)` + - Arguments: + - `parsed_config` - A configuration object that is the return value of the + `parse_config` method. You should set any configuration options needed by + the module here. +* `saml_response_to_user_attributes(self, saml_response, failures)` + - Arguments: + - `saml_response` - A `saml2.response.AuthnResponse` object to extract user + information from. + - `failures` - An `int` that represents the amount of times the returned + mxid localpart mapping has failed. This should be used + to create a deduplicated mxid localpart which should be + returned instead. For example, if this method returns + `john.doe` as the value of `mxid_localpart` in the returned + dict, and that is already taken on the homeserver, this + method will be called again with the same parameters but + with failures=1. The method should then return a different + `mxid_localpart` value, such as `john.doe1`. + - This method must return a dictionary, which will then be used by Synapse + to build a new user. The following keys are allowed: + * `mxid_localpart` - Required. The mxid localpart of the new user. + * `displayname` - The displayname of the new user. If not provided, will default to + the value of `mxid_localpart`. +* `parse_config(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A `dict` representing the parsed content of the + `saml2_config.user_mapping_provider.config` homeserver config option. + Runs on homeserver startup. Providers should extract any option values + they need here. + - Whatever is returned will be passed back to the user mapping provider module's + `__init__` method during construction. +* `get_saml_attributes(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A object resulting from a call to `parse_config`. + - Returns a tuple of two sets. The first set equates to the saml auth + response attributes that are required for the module to function, whereas + the second set consists of those attributes which can be used if available, + but are not necessary. + +## Synapse's Default Provider + +Synapse has a built-in SAML mapping provider if a custom provider isn't +specified in the config. It is located at +[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py). diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c7391f0c485e..4d44e631d14f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,15 +54,16 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true -# If set to 'false', requires authentication to access the server's public rooms -# directory through the client API. Defaults to 'true'. +# If set to 'true', removes the need for authentication to access the server's +# public rooms directory through the client API, meaning that anyone can +# query the room directory. Defaults to 'false'. # -#allow_public_rooms_without_auth: false +#allow_public_rooms_without_auth: true -# If set to 'false', forbids any other homeserver to fetch the server's public -# rooms directory via federation. Defaults to 'true'. +# If set to 'true', allows any other homeserver to fetch the server's public +# rooms directory via federation. Defaults to 'false'. # -#allow_public_rooms_over_federation: false +#allow_public_rooms_over_federation: true # The default room version for newly created rooms. # @@ -1249,33 +1250,58 @@ saml2_config: # #config_path: "CONFDIR/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # diff --git a/mypy.ini b/mypy.ini index 1d77c0ecc842..a66434b76b7f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,7 @@ [mypy] namespace_packages = True plugins = mypy_zope:plugin -follow_imports = normal +follow_imports = silent check_untyped_defs = True show_error_codes = True show_traceback = True diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index d20f6db17656..bf3862a38653 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -27,7 +27,7 @@ class Store(object): "_store_pdu_reference_hash_txn" ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] + simple_insert_txn = SQLBaseStore.__dict__["simple_insert_txn"] store = Store() diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 27a1ad1e7e43..1776d202c58a 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -58,10 +58,10 @@ if __name__ == "__main__": " on it." ) ) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--database-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="A database config file for either a SQLite3 database or a PostgreSQL one.", ) @@ -101,10 +101,7 @@ if __name__ == "__main__": # Instantiate and initialise the homeserver object. hs = MockHomeserver( - config, - database_engine, - db_conn, - db_config=config.database_config, + config, database_engine, db_conn, db_config=config.database_config, ) # setup instantiates the store within the homeserver object. hs.setup() @@ -112,13 +109,13 @@ if __name__ == "__main__": @defer.inlineCallbacks def run_background_updates(): - yield store.run_background_updates(sleep=False) + yield store.db.updates.run_background_updates(sleep=False) # Stop the reactor to exit the script once every background update is run. reactor.stop() # Apply all background updates on the database. - reactor.callWhenRunning(lambda: run_as_background_process( - "background_updates", run_background_updates - )) + reactor.callWhenRunning( + lambda: run_as_background_process("background_updates", run_background_updates) + ) reactor.run() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 0d3321682c6e..e393a9b2f7c4 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -47,6 +47,7 @@ from synapse.storage.data_stores.main.media_repository import ( from synapse.storage.data_stores.main.registration import ( RegistrationBackgroundUpdateStore, ) +from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore @@ -54,6 +55,7 @@ from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -131,54 +133,22 @@ class Store( EventsBackgroundUpdatesStore, MediaRepositoryBackgroundUpdateStore, RegistrationBackgroundUpdateStore, + RoomBackgroundUpdateStore, RoomMemberBackgroundUpdateStore, SearchBackgroundUpdateStore, StateBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore, StatsStore, ): - def __init__(self, db_conn, hs): - super().__init__(db_conn, hs) - self.db_pool = hs.get_db_pool() - - @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): - def r(conn): - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - return func( - LoggingTransaction(txn, desc, self.database_engine, [], []), - *args, - **kwargs - ) - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) - if i < N: - i += 1 - conn.rollback() - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", desc, e) - raise - - with PreserveLoggingContext(): - return (yield self.db_pool.runWithConnection(r)) - def execute(self, f, *args, **kwargs): - return self.runInteraction(f.__name__, f, *args, **kwargs) + return self.db.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.runInteraction("execute_sql", r) + return self.db.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -221,7 +191,7 @@ class Porter(object): def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store._simple_select_one( + row = yield self.postgres_store.db.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -231,12 +201,14 @@ class Porter(object): total_to_port = None if row is None: if table == "sent_transactions": - forward_chunk, already_ported, total_to_port = ( - yield self._setup_sent_transactions() - ) + ( + forward_chunk, + already_ported, + total_to_port, + ) = yield self._setup_sent_transactions() backward_chunk = 0 else: - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -266,7 +238,7 @@ class Porter(object): yield self.postgres_store.execute(delete_all) - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -320,7 +292,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -361,7 +333,9 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) + headers, frows, brows = yield self.sqlite_store.db.runInteraction( + "select", r + ) if frows or brows: if frows: @@ -375,7 +349,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -414,7 +388,7 @@ class Porter(object): return headers, rows - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -431,8 +405,8 @@ class Porter(object): rows_dict = [] for row in rows: d = dict(zip(headers, row)) - if "\0" in d['value']: - logger.warning('dropping search row %s', d) + if "\0" in d["value"]: + logger.warning("dropping search row %s", d) else: rows_dict.append(d) @@ -452,7 +426,7 @@ class Porter(object): ], ) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -502,17 +476,14 @@ class Porter(object): self.progress.set_state("Preparing %s" % config["name"]) conn = self.setup_db(config, engine) - db_pool = adbapi.ConnectionPool( - config["name"], **config["args"] - ) + db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) hs = MockHomeserver(self.hs_config, engine, conn, db_pool) - store = Store(conn, hs) + store = Store(Database(hs), conn, hs) - yield store.runInteraction( - "%s_engine.check_database" % config["name"], - engine.check_database, + yield store.db.runInteraction( + "%s_engine.check_database" % config["name"], engine.check_database, ) return store @@ -520,7 +491,9 @@ class Porter(object): @defer.inlineCallbacks def run_background_updates_on_postgres(self): # Manually apply all background updates on the PostgreSQL database. - postgres_ready = yield self.postgres_store.has_completed_background_updates() + postgres_ready = ( + yield self.postgres_store.db.updates.has_completed_background_updates() + ) if not postgres_ready: # Only say that we're running background updates when there are background @@ -528,9 +501,9 @@ class Porter(object): self.progress.set_state("Running background updates on PostgreSQL") while not postgres_ready: - yield self.postgres_store.do_next_background_update(100) + yield self.postgres_store.db.updates.do_next_background_update(100) postgres_ready = yield ( - self.postgres_store.has_completed_background_updates() + self.postgres_store.db.updates.has_completed_background_updates() ) @defer.inlineCallbacks @@ -539,7 +512,9 @@ class Porter(object): self.sqlite_store = yield self.build_db_store(self.sqlite_config) # Check if all background updates are done, abort if not. - updates_complete = yield self.sqlite_store.has_completed_background_updates() + updates_complete = ( + yield self.sqlite_store.db.updates.has_completed_background_updates() + ) if not updates_complete: sys.stderr.write( "Pending background updates exist in the SQLite3 database." @@ -580,22 +555,22 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction("alter_table", alter_table) + yield self.postgres_store.db.runInteraction("alter_table", alter_table) except Exception: # On Error Resume Next pass - yield self.postgres_store.runInteraction( + yield self.postgres_store.db.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store._simple_select_onecol( + sqlite_tables = yield self.sqlite_store.db.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store._simple_select_onecol( + postgres_tables = yield self.postgres_store.db.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -685,11 +660,11 @@ class Porter(object): rows = txn.fetchall() headers = [column[0] for column in txn.description] - ts_ind = headers.index('ts') + ts_ind = headers.index("ts") return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -722,7 +697,7 @@ class Porter(object): next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -735,7 +710,7 @@ class Porter(object): txn.execute( "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) - size, = txn.fetchone() + (size,) = txn.fetchone() return int(size) remaining_count = yield self.sqlite_store.execute(get_sent_table_size) @@ -782,10 +757,13 @@ class Porter(object): def _setup_state_group_id_seq(self): def r(txn): txn.execute("SELECT MAX(id) FROM state_groups") - next_id = txn.fetchone()[0] + 1 + curr_id = txn.fetchone()[0] + if not curr_id: + return + next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) ############################################## @@ -866,7 +844,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds) + duration_str = "%02dm %02ds" % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -876,7 +854,7 @@ class CursesProgress(Progress): left = float(self.total_remaining) / self.total_processed est_remaining = (int(now) - self.start_time) * left - est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) + est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" status = "Time spent: %s (est. remaining: %s)" % ( @@ -962,7 +940,7 @@ if __name__ == "__main__": description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." ) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--sqlite-database", required=True, @@ -971,12 +949,12 @@ if __name__ == "__main__": ) parser.add_argument( "--postgres-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', help="display a curses based progress UI" + "--curses", action="store_true", help="display a curses based progress UI" ) parser.add_argument( @@ -1052,3 +1030,4 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) + sys.exit(5) diff --git a/synapse/__init__.py b/synapse/__init__.py index f99de2f3f3ea..c67a51a8d5c0 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ except ImportError: pass -__version__ = "1.6.1" +__version__ = "1.7.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e3f086f1c361..0ade47e62404 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -147,3 +148,7 @@ class EventContentFields(object): # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bec13f08d824..6eab1f13f00f 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2ac7d5c064ab..9c96816096a6 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -269,7 +269,7 @@ def handle_sighup(*args, **kwargs): # It is now safe to start your Synapse. hs.start_listening(listeners) - hs.get_datastore().start_profiling() + hs.get_datastore().db.start_profiling() setup_sentry(hs) setup_sdnotify(hs) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 448e45e00f0c..f24920a7d615 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,6 +40,7 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -59,8 +60,8 @@ class FederationSenderSlaveStore( SlavedDeviceStore, SlavedPresenceStore, ): - def __init__(self, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs) # We pull out the current federation stream position now so that we # always have a known value for the federation position in memory so diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883b3fb70b8c..032010600ae5 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -68,9 +68,9 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer -from synapse.storage import DataStore, are_all_users_on_domain +from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup, create_engine -from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database +from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -294,22 +294,6 @@ def start_listening(self, listeners): else: logger.warning("Unrecognized listener type: %s", listener["type"]) - def run_startup_checks(self, db_conn, database_engine): - all_users_native = are_all_users_on_domain( - db_conn.cursor(), database_engine, self.hostname - ) - if not all_users_native: - quit_with_error( - "Found users in database not native to %s!\n" - "You cannot changed a synapse server_name after it's been configured" - % (self.hostname,) - ) - - try: - database_engine.check_database(db_conn.cursor()) - except IncorrectDatabaseSetup as e: - quit_with_error(str(e)) - # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") @@ -357,16 +341,12 @@ def setup(config_options): synapse.config.logger.setup_logging(hs, config, use_worker_options=False) - logger.info("Preparing database: %s...", config.database_config["name"]) + logger.info("Setting up server") try: - with hs.get_db_conn(run_new_connection=False) as db_conn: - prepare_database(db_conn, database_engine, config=config) - database_engine.on_new_connection(db_conn) - - hs.run_startup_checks(db_conn, database_engine) - - db_conn.commit() + hs.setup() + except IncorrectDatabaseSetup as e: + quit_with_error(str(e)) except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" @@ -375,9 +355,6 @@ def setup(config_options): ) sys.exit(1) - logger.info("Database prepared in %s.", config.database_config["name"]) - - hs.setup() hs.setup_master() @defer.inlineCallbacks @@ -436,7 +413,7 @@ def start(): _base.start(hs, config.listeners) hs.get_pusherpool().start() - hs.get_datastore().start_doing_background_updates() + hs.get_datastore().db.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) @@ -542,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.get_datastore().database_engine_name - stats["database_server_version"] = hs.get_datastore().get_server_version() + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db.engine.server_version + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index b14da09f47fb..288ee64b4265 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -151,7 +151,7 @@ def send_stop_syncing(self): def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? - pass + return defer.succeed(None) get_states = __func__(PresenceHandler.get_states) get_state = __func__(PresenceHandler.get_state) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 6cb100319f25..c01fb34a9bd0 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -43,6 +43,7 @@ from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree @@ -60,11 +61,11 @@ class UserDirectorySlaveStore( UserDirectoryStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c5ea2d43a131..b91414aa3546 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -14,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re +import logging from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import ( - map_username_to_mxid_localpart, - mxid_localpart_allowed_characters, -) -from synapse.util.module_loader import load_python_module +from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + +DEFAULT_USER_MAPPING_PROVIDER = ( + "synapse.handlers.saml_handler.DefaultSamlMappingProvider" +) + def _dict_merge(merge_dict, into_dict): """Do a deep merge of two dicts @@ -75,15 +77,69 @@ def read_config(self, config, **kwargs): self.saml2_enabled = True - self.saml2_mxid_source_attribute = saml2_config.get( - "mxid_source_attribute", "uid" - ) - self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) - saml2_config_dict = self._default_saml_config_dict() + # user_mapping_provider may be None if the key is present but has no value + ump_dict = saml2_config.get("user_mapping_provider") or {} + + # Use the default user mapping provider if not set + ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + + # Ensure a config is present + ump_dict["config"] = ump_dict.get("config") or {} + + if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: + # Load deprecated options for use by the default module + old_mxid_source_attribute = saml2_config.get("mxid_source_attribute") + if old_mxid_source_attribute: + logger.warning( + "The config option saml2_config.mxid_source_attribute is deprecated. " + "Please use saml2_config.user_mapping_provider.config" + ".mxid_source_attribute instead." + ) + ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute + + old_mxid_mapping = saml2_config.get("mxid_mapping") + if old_mxid_mapping: + logger.warning( + "The config option saml2_config.mxid_mapping is deprecated. Please " + "use saml2_config.user_mapping_provider.config.mxid_mapping instead." + ) + ump_dict["config"]["mxid_mapping"] = old_mxid_mapping + + # Retrieve an instance of the module's class + # Pass the config dictionary to the module for processing + ( + self.saml2_user_mapping_provider_class, + self.saml2_user_mapping_provider_config, + ) = load_module(ump_dict) + + # Ensure loaded user mapping module has defined all necessary methods + # Note parse_config() is already checked during the call to load_module + required_methods = [ + "get_saml_attributes", + "saml_response_to_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.saml2_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by saml2_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + # Get the desired saml auth response attributes from the module + saml2_config_dict = self._default_saml_config_dict( + *self.saml2_user_mapping_provider_class.get_saml_attributes( + self.saml2_user_mapping_provider_config + ) + ) _dict_merge( merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict ) @@ -103,22 +159,27 @@ def read_config(self, config, **kwargs): saml2_config.get("saml_session_lifetime", "5m") ) - mapping = saml2_config.get("mxid_mapping", "hexencode") - try: - self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping] - except KeyError: - raise ConfigError("%s is not a known mxid_mapping" % (mapping,)) - - def _default_saml_config_dict(self): + def _default_saml_config_dict( + self, required_attributes: set, optional_attributes: set + ): + """Generate a configuration dictionary with required and optional attributes that + will be needed to process new user registration + + Args: + required_attributes: SAML auth response attributes that are + necessary to function + optional_attributes: SAML auth response attributes that can be used to add + additional information to Synapse user accounts, but are not required + + Returns: + dict: A SAML configuration dictionary + """ import saml2 public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") - required_attributes = {"uid", self.saml2_mxid_source_attribute} - - optional_attributes = {"displayName"} if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes @@ -207,33 +268,58 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # #config_path: "%(config_dir_path)s/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace - - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider + + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # @@ -241,23 +327,3 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): """ % { "config_dir_path": config_dir_path } - - -DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) -) - - -def dot_replace_for_mxid(username: str) -> str: - username = username.lower() - username = DOT_REPLACE_PATTERN.sub(".", username) - - # regular mxids aren't allowed to start with an underscore either - username = re.sub("^_", "", username) - return username - - -MXID_MAPPER_MAP = { - "hexencode": map_username_to_mxid_localpart, - "dotreplace": dot_replace_for_mxid, -} diff --git a/synapse/config/server.py b/synapse/config/server.py index 7a9d7116692d..a4bef009361a 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -118,15 +118,16 @@ def read_config(self, config, **kwargs): self.allow_public_rooms_without_auth = False self.allow_public_rooms_over_federation = False else: - # If set to 'False', requires authentication to access the server's public - # rooms directory through the client API. Defaults to 'True'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. self.allow_public_rooms_without_auth = config.get( - "allow_public_rooms_without_auth", True + "allow_public_rooms_without_auth", False ) - # If set to 'False', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'True'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. self.allow_public_rooms_over_federation = config.get( - "allow_public_rooms_over_federation", True + "allow_public_rooms_over_federation", False ) default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) @@ -490,6 +491,8 @@ class LimitRemoteRoomsConfig(object): "cleanup_extremities_with_dummy_events", True ) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) @@ -618,15 +621,16 @@ def generate_config_section( # #require_auth_for_profile_requests: true - # If set to 'false', requires authentication to access the server's public rooms - # directory through the client API. Defaults to 'true'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. # - #allow_public_rooms_without_auth: false + #allow_public_rooms_without_auth: true - # If set to 'false', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'true'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. # - #allow_public_rooms_over_federation: false + #allow_public_rooms_over_federation: true # The default room version for newly created rooms. # diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ec3243b27bee..c940b8447080 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru Returns: if the auth checks pass. """ + assert isinstance(auth_events, dict) + if do_size_check: _check_size_limits(event) @@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warning("Trusting event: %s", event.event_id) - return - if event.type == EventTypes.Create: sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 27f6aff00440..d396e6564f6c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,8 +18,6 @@ import itertools import logging -from six.moves import range - from prometheus_client import Counter from twisted.internet import defer @@ -39,7 +37,7 @@ ) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache @@ -310,162 +308,26 @@ def get_pdu( return signed_pdu @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + def get_room_state_ids(self, destination: str, room_id: str, event_id: str): + """Calls the /state_ids endpoint to fetch the state at a particular point + in the room, and the auth events for the given event Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids) """ - try: - # First we try and ask for just the IDs, as thats far quicker if - # we have most of the state and auth_chain already. - # However, this may 404 if the other side has an old synapse. - result = yield self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - state_event_ids = result["pdu_ids"] - auth_event_ids = result.get("auth_chain_ids", []) - - fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( - destination, room_id, set(state_event_ids + auth_event_ids) - ) - - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - event_map = {ev.event_id: ev for ev in fetched_events} - - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [ - event_map[e_id] for e_id in auth_event_ids if e_id in event_map - ] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - except HttpResponseException as e: - if e.code == 400 or e.code == 404: - logger.info("Failed to use get_room_state_ids API, falling back") - else: - raise e - - result = yield self.transport_layer.get_room_state( + result = yield self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id ) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - - pdus = [ - event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"] - ] - - auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in result.get("auth_chain", []) - ] - - seen_events = yield self.store.get_events( - [ev.event_id for ev in itertools.chain(pdus, auth_chain)] - ) - - signed_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, - [p for p in pdus if p.event_id not in seen_events], - outlier=True, - room_version=room_version, - ) - signed_pdus.extend( - seen_events[p.event_id] for p in pdus if p.event_id in seen_events - ) - - signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, - [p for p in auth_chain if p.event_id not in seen_events], - outlier=True, - room_version=room_version, - ) - signed_auth.extend( - seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events - ) - - signed_auth.sort(key=lambda e: e.depth) - - return signed_pdus, signed_auth - - @defer.inlineCallbacks - def get_events_from_store_or_dest(self, destination, room_id, event_ids): - """Fetch events from a remote destination, checking if we already have them. - - Args: - destination (str) - room_id (str) - event_ids (list) - - Returns: - Deferred: A deferred resolving to a 2-tuple where the first is a list of - events and the second is a list of event ids that we failed to fetch. - """ - seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = list(seen_events.values()) - - failed_to_fetch = set() - - missing_events = set(event_ids) - for k in seen_events: - missing_events.discard(k) - - if not missing_events: - return signed_events, failed_to_fetch - - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, - ) - - room_version = yield self.store.get_room_version(room_id) - - batch_size = 20 - missing_events = list(missing_events) - for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i : i + batch_size]) - - deferreds = [ - run_in_background( - self.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - for success, result in res: - if success and result: - signed_events.append(result) - batch.discard(result.event_id) + state_event_ids = result["pdu_ids"] + auth_event_ids = result.get("auth_chain_ids", []) - # We removed all events we successfully fetched from `batch` - failed_to_fetch.update(batch) + if not isinstance(state_event_ids, list) or not isinstance( + auth_event_ids, list + ): + raise Exception("invalid response from /state_ids") - return signed_events, failed_to_fetch + return state_event_ids, auth_event_ids @defer.inlineCallbacks @log_function diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index dc95ab2113ef..46dba84cacfc 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -38,30 +38,6 @@ def __init__(self, hs): self.server_name = hs.hostname self.client = hs.get_http_client() - @log_function - def get_room_state(self, destination, room_id, event_id): - """ Requests all state for a given room from the given server at the - given event. - - Args: - destination (str): The host name of the remote homeserver we want - to get the state from. - context (str): The name of the context we want the state of - event_id (str): The event we want the context at. - - Returns: - Deferred: Results in a dict received from the remote homeserver. - """ - logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - - path = _create_v1_path("/state/%s", room_id) - return self.client.get_json( - destination, - path=path, - args={"event_id": event_id}, - try_trailing_slash_on_400=True, - ) - @log_function def get_room_state_ids(self, destination, room_id, event_id): """ Requests all state for a given room from the given server at the diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 6407d56f8e7e..14449b9a1e66 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -56,7 +56,7 @@ def get_whois(self, user): @defer.inlineCallbacks def get_users(self): - """Function to reterive a list of users in users table. + """Function to retrieve a list of users in users table. Args: Returns: @@ -67,19 +67,22 @@ def get_users(self): return ret @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. + def get_users_paginate(self, start, limit, name, guests, deactivated): + """Function to retrieve a paginated list of users from + users list. This will return a json list of users. Args: - order (str): column name to order the select by this column start (int): start number to begin the query from - limit (int): number of rows to reterive + limit (int): number of rows to retrieve + name (string): filter for user names + guests (bool): whether to in include guest users + deactivated (bool): whether to include deactivated users Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} + defer.Deferred: resolves to json list[dict[str, Any]] """ - ret = yield self.store.get_users_paginate(order, start, limit) + ret = yield self.store.get_users_paginate( + start, limit, name, guests, deactivated + ) return ret diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 28c12753c1ac..57a10daefd41 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,7 +264,6 @@ def do_remote_query(destination): return ret - @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -284,35 +283,14 @@ def get_cross_signing_keys_from_cache(self, query, from_user_id): self_signing_keys = {} user_signing_keys = {} - for user_id in query: - # XXX: consider changing the store functions to allow querying - # multiple users simultaneously. - key = yield self.store.get_e2e_cross_signing_key( - user_id, "master", from_user_id - ) - if key: - master_keys[user_id] = key - - key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing", from_user_id - ) - if key: - self_signing_keys[user_id] = key - - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - if from_user_id == user_id: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing", from_user_id - ) - if key: - user_signing_keys[user_id] = key - - return { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } + # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 + return defer.succeed( + { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } + ) @trace @defer.inlineCallbacks diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 45fe13c62ff3..ec18a42a68b8 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -16,8 +16,6 @@ import logging import random -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase @@ -50,9 +48,8 @@ def __init__(self, hs): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks @log_function - def get_stream( + async def get_stream( self, auth_user_id, pagin_config, @@ -69,17 +66,17 @@ def get_stream( """ if room_id: - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(auth_user_id) auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() - context = yield presence_handler.user_syncing( + context = await presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence ) with context: @@ -91,7 +88,7 @@ def get_stream( # thundering herds on restart. timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) - events, tokens = yield self.notifier.get_events_for( + events, tokens = await self.notifier.get_events_for( auth_user, pagin_config, timeout, @@ -112,14 +109,14 @@ def get_stream( # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room( + users = await self.state.get_current_users_in_room( event.room_id ) - states = yield presence_handler.get_states(users, as_event=True) + states = await presence_handler.get_states(users, as_event=True) to_add.extend(states) else: - ev = yield presence_handler.get_state( + ev = await presence_handler.get_state( UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -128,7 +125,7 @@ def get_stream( time_now = self.clock.time_msec() - chunks = yield self._event_serializer.serialize_events( + chunks = await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, @@ -151,8 +148,7 @@ def __init__(self, hs): super(EventHandler, self).__init__(hs) self.storage = hs.get_storage() - @defer.inlineCallbacks - def get_event(self, user, room_id, event_id): + async def get_event(self, user, room_id, event_id): """Retrieve a single specified event. Args: @@ -167,15 +163,15 @@ def get_event(self, user, room_id, event_id): AuthError if the user does not have the rights to inspect this event. """ - event = yield self.store.get_event(event_id, check_room_id=room_id) + event = await self.store.get_event(event_id, check_room_id=room_id) if not event: return None - users = yield self.store.get_users_in_room(event.room_id) + users = await self.store.get_users_in_room(event.room_id) is_peeking = user.to_string() not in users - filtered = yield filter_events_for_client( + filtered = await filter_events_for_client( self.storage, user.to_string(), [event], is_peeking=is_peeking ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 51d95ae1dbb7..ed88fde0fe9c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,11 +19,13 @@ import itertools import logging +from typing import Dict, Iterable, Optional, Sequence, Tuple import six from six import iteritems, itervalues from six.moves import http_client, zip +import attr from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -45,6 +47,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( @@ -62,7 +65,7 @@ from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -73,6 +76,23 @@ logger = logging.getLogger(__name__) +@attr.s +class _NewEventInfo: + """Holds information about a received event, ready for passing to _handle_new_events + + Attributes: + event: the received event + + state: the state at that event + + auth_events: the auth_event map for that event + """ + + event = attr.ib(type=EventBase) + state = attr.ib(type=Optional[Sequence[EventBase]], default=None) + auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None) + + def shortstr(iterable, maxitems=5): """If iterable has maxitems or fewer, return the stringification of a list containing those items. @@ -122,6 +142,7 @@ def __init__(self, hs): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() @@ -142,6 +163,8 @@ def __init__(self, hs): self.third_party_event_rules = hs.get_third_party_event_rules() + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or @@ -357,11 +380,9 @@ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): ( remote_state, got_auth_chain, - ) = yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + ) = yield self._get_state_for_room(origin, room_id, p) - # we want the state *after* p; get_state_for_room returns the + # we want the state *after* p; _get_state_for_room returns the # state *before* p. remote_event = yield self.federation_client.get_pdu( [origin], p, room_version, outlier=True @@ -561,6 +582,97 @@ def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): else: raise + @defer.inlineCallbacks + @log_function + def _get_state_for_room(self, destination, room_id, event_id): + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. + + Returns: + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. + """ + ( + state_event_ids, + auth_event_ids, + ) = yield self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + desired_events = set(state_event_ids + auth_event_ids) + event_map = yield self._get_events_from_store_or_dest( + destination, room_id, desired_events + ) + + failed_to_fetch = desired_events - event_map.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s: %s", + room_id, + failed_to_fetch, + ) + + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + + auth_chain.sort(key=lambda e: e.depth) + + return pdus, auth_chain + + @defer.inlineCallbacks + def _get_events_from_store_or_dest(self, destination, room_id, event_ids): + """Fetch events from a remote destination, checking if we already have them. + + Args: + destination (str) + room_id (str) + event_ids (Iterable[str]) + + Returns: + Deferred[dict[str, EventBase]]: A deferred resolving to a map + from event_id to event + """ + fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) + + missing_events = set(event_ids) - fetched_events.keys() + + if not missing_events: + return fetched_events + + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + event_ids, + ) + + room_version = yield self.store.get_room_version(room_id) + + # XXX 20 requests at once? really? + for batch in batch_iter(missing_events, 20): + deferreds = [ + run_in_background( + self.federation_client.get_pdu, + destinations=[destination], + event_id=e_id, + room_version=room_version, + ) + for e_id in batch + ] + + res = yield make_deferred_yieldable( + defer.DeferredList(deferreds, consumeErrors=True) + ) + for success, result in res: + if success and result: + fetched_events[result.event_id] = result + + return fetched_events + @defer.inlineCallbacks def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it @@ -595,14 +707,14 @@ def _process_received_pdu(self, origin, event, state, auth_chain): for e in auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } - event_infos.append({"event": e, "auth_events": auth}) + event_infos.append(_NewEventInfo(event=e, auth_events=auth)) seen_ids.add(e.event_id) logger.info( "[%s %s] persisting newly-received auth/state events %s", room_id, event_id, - [e["event"].event_id for e in event_infos], + [e.event.event_id for e in event_infos], ) yield self._handle_new_events(origin, event_infos) @@ -701,7 +813,7 @@ def backfill(self, dest, room_id, limit, extremities): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.federation_client.get_state_for_room( + state, auth = yield self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) @@ -793,9 +905,9 @@ def backfill(self, dest, room_id, limit, extremities): a.internal_metadata.outlier = True ev_infos.append( - { - "event": a, - "auth_events": { + _NewEventInfo( + event=a, + auth_events={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -803,7 +915,7 @@ def backfill(self, dest, room_id, limit, extremities): for a_id in a.auth_event_ids() if a_id in auth_events }, - } + ) ) # Step 1b: persist the events in the chunk we fetched state for (i.e. @@ -815,10 +927,10 @@ def backfill(self, dest, room_id, limit, extremities): assert not ev.internal_metadata.is_outlier() ev_infos.append( - { - "event": ev, - "state": events_to_state[e_id], - "auth_events": { + _NewEventInfo( + event=ev, + state=events_to_state[e_id], + auth_events={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -826,7 +938,7 @@ def backfill(self, dest, room_id, limit, extremities): for a_id in ev.auth_event_ids() if a_id in auth_events }, - } + ) ) yield self._handle_new_events(dest, ev_infos, backfilled=True) @@ -1713,7 +1825,12 @@ def _handle_new_event( return context @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False): + def _handle_new_events( + self, + origin: str, + event_infos: Iterable[_NewEventInfo], + backfilled: bool = False, + ): """Creates the appropriate contexts and persists events. The events should not depend on one another, e.g. this should be used to persist a bunch of outliers, but not a chunk of individual events that depend @@ -1723,14 +1840,14 @@ def _handle_new_events(self, origin, event_infos, backfilled=False): """ @defer.inlineCallbacks - def prep(ev_info): - event = ev_info["event"] + def prep(ev_info: _NewEventInfo): + event = ev_info.event with nested_logging_context(suffix=event.event_id): res = yield self._prep_event( origin, event, - state=ev_info.get("state"), - auth_events=ev_info.get("auth_events"), + state=ev_info.state, + auth_events=ev_info.auth_events, backfilled=backfilled, ) return res @@ -1744,7 +1861,7 @@ def prep(ev_info): yield self.persist_events_and_notify( [ - (ev_info["event"], context) + (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) ], backfilled=backfilled, @@ -1846,7 +1963,14 @@ def _persist_auth_tree(self, origin, auth_events, state, event): yield self.persist_events_and_notify([(event, new_event_context)]) @defer.inlineCallbacks - def _prep_event(self, origin, event, state, auth_events, backfilled): + def _prep_event( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + auth_events: Optional[Dict[Tuple[str, str], EventBase]], + backfilled: bool, + ): """ Args: @@ -1854,7 +1978,7 @@ def _prep_event(self, origin, event, state, auth_events, backfilled): event: state: auth_events: - backfilled (bool) + backfilled: Returns: Deferred, which resolves to synapse.events.snapshot.EventContext @@ -1890,15 +2014,16 @@ def _prep_event(self, origin, event, state, auth_events, backfilled): return context @defer.inlineCallbacks - def _check_for_soft_fail(self, event, state, backfilled): + def _check_for_soft_fail( + self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool + ): """Checks if we should soft fail the event, if so marks the event as such. Args: - event (FrozenEvent) - state (dict|None): The state at the event if we don't have all the - event's prev events - backfilled (bool): Whether the event is from backfill + event + state: The state at the event if we don't have all the event's prev events + backfilled: Whether the event is from backfill Returns: Deferred @@ -2122,14 +2247,9 @@ def _update_auth_events_and_context_for_auth( # # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections(missing_auth) - logger.debug("Found events %s in the store", have_events) - missing_auth.difference_update(have_events.keys()) - else: - have_events = {} - - have_events.update({e.event_id: "" for e in auth_events.values()}) + have_events = yield self.store.have_seen_events(missing_auth) + logger.debug("Events %s are in the store", have_events) + missing_auth.difference_update(have_events) if missing_auth: # If we don't have all the auth events, we need to get them. @@ -2175,9 +2295,6 @@ def _update_auth_events_and_context_for_auth( except AuthError: pass - have_events = yield self.store.get_seen_events_with_rejections( - event.auth_event_ids() - ) except Exception: logger.exception("Failed to get auth chain") @@ -2203,43 +2320,53 @@ def _update_auth_events_and_context_for_auth( different_auth, ) - # now we state-resolve between our own idea of the auth events, and the remote's - # idea of them. + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? + different_events = yield self.store.get_events_as_list(different_auth) - room_version = yield self.store.get_room_version(event.room_id) - different_event_ids = [ - d for d in different_auth if d in have_events and not have_events[d] - ] + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) - if different_event_ids: - # XXX: currently this checks for redactions but I'm not convinced that is - # necessary? - different_events = yield self.store.get_events_as_list(different_event_ids) + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context - local_view = dict(auth_events) - remote_view = dict(auth_events) - remote_view.update({(d.type, d.state_key): d for d in different_events}) + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. - new_state = yield self.state_handler.resolve_events( - room_version, - [list(local_view.values()), list(remote_view.values())], - event, - ) + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() - logger.info( - "After state res: updating auth_events with new state %s", - { - (d.type, d.state_key): d.event_id - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, - ) + room_version = yield self.store.get_room_version(event.room_id) + new_state = yield self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) + + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) - auth_events.update(new_state) + auth_events.update(new_state) - context = yield self._update_context_for_auth_events( - event, context, auth_events - ) + context = yield self._update_context_for_auth_events( + event, context, auth_events + ) return context @@ -2718,6 +2845,11 @@ def persist_events_and_notify(self, event_and_contexts, backfilled=False): event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: yield self._notify_persisted_event(event, max_stream_id) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 81dce96f4b64..73c110a92b82 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -26,7 +26,7 @@ from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,7 +41,7 @@ def __init__(self, hs): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -79,17 +79,14 @@ def snapshot_all_rooms( as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) @defer.inlineCallbacks diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5b1e90cef345..bf9add7fe2e8 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -63,6 +71,17 @@ def __init__(self, hs): self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._is_worker_app = bool(hs.config.worker_app) + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -226,6 +245,100 @@ def get_joined_members(self, requester, room_id): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + assert not self._is_worker_app + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. + + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. + yield self._schedule_next_expiry() + # The duration (in ms) after which rooms should be removed # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try @@ -251,6 +364,8 @@ def __init__(self, hs): self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self.room_invite_state_types = self.hs.config.room_invite_state_types + self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users @@ -296,6 +411,10 @@ def __init__(self, hs): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def create_event( self, @@ -800,7 +919,7 @@ def is_inviter_member_event(e): state_to_include_ids = [ e_id for k, e_id in iteritems(current_state_ids) - if k[0] in self.hs.config.room_invite_state_types + if k[0] in self.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -878,6 +997,10 @@ def is_inviter_member_event(e): event, context=context ) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 95806af41e4f..8a7d965feb02 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -266,7 +266,7 @@ def register_user( } # Bind email to new account - yield self._register_email_threepid(user_id, threepid_dict, None, False) + yield self._register_email_threepid(user_id, threepid_dict, None) return user_id diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e92b2eafd568..22768e97ff40 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -198,21 +199,21 @@ def _upgrade_room(self, requester, old_room_id, new_version): # finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state + requester, old_room_id, new_room_id, old_room_state, ) return new_room_id @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state + self, requester, old_room_id, new_room_id, old_room_state, ): """Send updated power levels in both rooms after an upgrade Args: requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_room_id (unicode): the id of the replacement room + old_room_id (str): the id of the room to be replaced + new_room_id (str): the id of the replacement room old_room_state (dict[tuple[str, str], str]): the state map for the old room Returns: @@ -298,7 +299,7 @@ def clone_existing_room( tombstone_event_id (unicode|str): the ID of the tombstone event in the old room. Returns: - Deferred[None] + Deferred """ user_id = requester.user.to_string() @@ -333,6 +334,7 @@ def clone_existing_room( (EventTypes.Encryption, ""), (EventTypes.ServerACL, ""), (EventTypes.RelatedGroups, ""), + (EventTypes.PowerLevels, ""), ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( @@ -346,6 +348,31 @@ def clone_existing_room( if old_event: initial_state[k] = old_event.content + # Resolve the minimum power level required to send any state event + # We will give the upgrading user this power level temporarily (if necessary) such that + # they are able to copy all of the state events over, then revert them back to their + # original power level afterwards in _update_upgraded_room_pls + + # Copy over user power levels now as this will not be possible with >100PL users once + # the room has been created + + power_levels = initial_state[(EventTypes.PowerLevels, "")] + + # Calculate the minimum power level needed to clone the room + event_power_levels = power_levels.get("events", {}) + state_default = power_levels.get("state_default", 0) + ban = power_levels.get("ban") + needed_power_level = max(state_default, ban, max(event_power_levels.values())) + + # Raise the requester's power level in the new room if necessary + current_power_level = power_levels["users"][requester.user.to_string()] + if current_power_level < needed_power_level: + # Assign this power level to the requester + power_levels["users"][requester.user.to_string()] = needed_power_level + + # Set the power levels to the modified state + initial_state[(EventTypes.PowerLevels, "")] = power_levels + yield self._send_events_for_new_room( requester, new_room_id, @@ -874,6 +901,10 @@ def filter_evts(events): room_id, event_id, before_limit, after_limit, event_filter ) + if event_filter: + results["events_before"] = event_filter.filter(results["events_before"]) + results["events_after"] = event_filter.filter(results["events_after"]) + results["events_before"] = yield filter_evts(results["events_before"]) results["events_after"] = yield filter_evts(results["events_after"]) results["event"] = event @@ -902,7 +933,12 @@ def filter_evts(events): state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) - results["state"] = list(state[last_event_id].values()) + + state_events = list(state[last_event_id].values()) + if event_filter: + state_events = event_filter.filter(state_events) + + results["state"] = state_events # We use a dummy token here as we only care about the room portion of # the token, which we replace. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index cc9e6b9bd018..0082f85c2666 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re +from typing import Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) @@ -37,11 +53,14 @@ def __init__(self, hs): self._datastore = hs.get_datastore() self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime - self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) - self._mxid_mapper = hs.config.saml2_mxid_mapper + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config + ) # identifier for the external_ids table self._auth_provider_id = "saml" @@ -118,22 +137,10 @@ async def _map_saml_response_to_user(self, resp_bytes): remote_user_id = saml2_auth.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") - - try: - mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) + raise SynapseError(400, "'uid' not in SAML2 response") self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - displayName = saml2_auth.ava.get("displayName", [None])[0] - with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -173,22 +180,46 @@ async def _map_saml_response_to_user(self, resp_bytes): ) return registered_user_id - # figure out a new mxid for this user - base_mxid_localpart = self._mxid_mapper(mxid_source) + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i + ) + + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + logger.error( + "SAML mapping provider plugin did not return a " + "mxid_localpart object" + ) + raise SynapseError(500, "Error parsing SAML2 response") - suffix = 0 - while True: - localpart = base_mxid_localpart + (str(suffix) if suffix else "") + displayname = attribute_dict.get("displayname") + + # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( UserID(localpart, self._hostname).to_string() ): + # This mxid is free break - suffix += 1 - logger.info("Allocating mxid for new user with localpart %s", localpart) + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayName + localpart=localpart, default_display_name=displayname ) + await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) @@ -205,9 +236,120 @@ def expire_sessions(self): del self._outstanding_requests_dict[reqid] +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} + + @attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() - # time the session was created, in milliseconds - creation_time = attr.ib() + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + def saml_response_to_user_attributes( + self, saml_response: saml2.response.AuthnResponse, failures: int = 0, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + return { + "mxid_localpart": localpart, + "displayname": displayname, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + tuple[set,set]: The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName"} diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b536d410e519..2d3b8ba73c6a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -22,8 +22,6 @@ from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.logging.context import LoggingContext from synapse.push.clientformat import format_push_rules_for_user @@ -241,8 +239,7 @@ def __init__(self, hs): expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) - @defer.inlineCallbacks - def wait_for_sync_for_user( + async def wait_for_sync_for_user( self, sync_config, since_token=None, timeout=0, full_state=False ): """Get the sync for a client if we have new data for it now. Otherwise @@ -255,9 +252,9 @@ def wait_for_sync_for_user( # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) - res = yield self.response_cache.wrap( + res = await self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, sync_config, @@ -267,8 +264,9 @@ def wait_for_sync_for_user( ) return res - @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): + async def _wait_for_sync_for_user( + self, sync_config, since_token, timeout, full_state + ): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -283,7 +281,7 @@ def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state) if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = yield self.current_sync_for_user( + result = await self.current_sync_for_user( sync_config, since_token, full_state=full_state ) else: @@ -291,7 +289,7 @@ def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state) def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) - result = yield self.notifier.wait_for_events( + result = await self.notifier.wait_for_events( sync_config.user.to_string(), timeout, current_sync_callback, @@ -314,15 +312,13 @@ def current_sync_for_user(self, sync_config, since_token=None, full_state=False) """ return self.generate_sync_result(sync_config, since_token, full_state) - @defer.inlineCallbacks - def push_rules_for_user(self, user): + async def push_rules_for_user(self, user): user_id = user.to_string() - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(user, rules) return rules - @defer.inlineCallbacks - def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): + async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): """Get the ephemeral events for each room the user is in Args: sync_result_builder(SyncResultBuilder) @@ -343,7 +339,7 @@ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events( + typing, typing_key = await typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -365,7 +361,7 @@ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events( + receipts, receipt_key = await receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -382,8 +378,7 @@ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): return now_token, ephemeral_by_room - @defer.inlineCallbacks - def _load_filtered_recents( + async def _load_filtered_recents( self, room_id, sync_config, @@ -415,10 +410,10 @@ def _load_filtered_recents( # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - recents = yield filter_events_for_client( + recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), recents, @@ -449,14 +444,14 @@ def _load_filtered_recents( # Otherwise, we want to return the last N events in the room # in toplogical ordering. if since_key: - events, end_key = yield self.store.get_room_events_stream_for_room( + events, end_key = await self.store.get_room_events_stream_for_room( room_id, limit=load_limit + 1, from_key=since_key, to_key=end_key, ) else: - events, end_key = yield self.store.get_recent_events_for_room( + events, end_key = await self.store.get_recent_events_for_room( room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( @@ -468,10 +463,10 @@ def _load_filtered_recents( # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - loaded_recents = yield filter_events_for_client( + loaded_recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), loaded_recents, @@ -498,8 +493,7 @@ def _load_filtered_recents( limited=limited or newly_joined_room, ) - @defer.inlineCallbacks - def get_state_after_event(self, event, state_filter=StateFilter.all()): + async def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event @@ -511,7 +505,7 @@ def get_state_after_event(self, event, state_filter=StateFilter.all()): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -519,8 +513,9 @@ def get_state_after_event(self, event, state_filter=StateFilter.all()): state_ids[(event.type, event.state_key)] = event.event_id return state_ids - @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): + async def get_state_at( + self, room_id, stream_position, state_filter=StateFilter.all() + ): """ Get the room state at a particular stream position Args: @@ -536,13 +531,13 @@ def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()) # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] - state = yield self.get_state_after_event( + state = await self.get_state_after_event( last_event, state_filter=state_filter ) @@ -551,8 +546,7 @@ def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()) state = {} return state - @defer.inlineCallbacks - def compute_summary(self, room_id, sync_config, batch, state, now_token): + async def compute_summary(self, room_id, sync_config, batch, state, now_token): """ Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds @@ -574,7 +568,7 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): # FIXME: we could/should get this from room_stats when matthew/stats lands # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 - last_events, _ = yield self.store.get_recent_event_ids_for_room( + last_events, _ = await self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1 ) @@ -582,7 +576,7 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): return None last_event = last_events[-1] - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -590,7 +584,7 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): ) # this is heavily cached, thus: fast. - details = yield self.store.get_room_summary(room_id) + details = await self.store.get_room_summary(room_id) name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) @@ -608,12 +602,12 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): # calculating heroes. Empty strings are falsey, so we check # for the "name" value and default to an empty string. if name_id: - name = yield self.store.get_event(name_id, allow_none=True) + name = await self.store.get_event(name_id, allow_none=True) if name and name.content.get("name"): return summary if canonical_alias_id: - canonical_alias = yield self.store.get_event( + canonical_alias = await self.store.get_event( canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): @@ -678,7 +672,7 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): ) ] - missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = await self.store.get_events(missing_hero_event_ids) missing_hero_state = missing_hero_state.values() for s in missing_hero_state: @@ -697,8 +691,7 @@ def get_lazy_loaded_members_cache(self, cache_key): logger.debug("found LruCache for %r", cache_key) return cache - @defer.inlineCallbacks - def compute_state_delta( + async def compute_state_delta( self, room_id, batch, sync_config, since_token, now_token, full_state ): """ Works out the difference in state between the start of the timeline @@ -759,16 +752,16 @@ def compute_state_delta( if full_state: if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -783,13 +776,13 @@ def compute_state_delta( ) elif batch.limited: if batch: - state_at_timeline_start = yield self.state_store.get_state_ids_for_event( + state_at_timeline_start = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: # We can get here if the user has ignored the senders of all # the recent events. - state_at_timeline_start = yield self.get_state_at( + state_at_timeline_start = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -807,19 +800,19 @@ def compute_state_delta( # about them). state_filter = StateFilter.all() - state_at_previous_sync = yield self.get_state_at( + state_at_previous_sync = await self.get_state_at( room_id, stream_position=since_token, state_filter=state_filter ) if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: # Its not clear how we get here, but empirically we do # (#5407). Logging has been added elsewhere to try and # figure out where this state comes from. - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -843,7 +836,7 @@ def compute_state_delta( # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( @@ -883,7 +876,7 @@ def compute_state_delta( state = {} if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) return { (e.type, e.state_key): e @@ -892,10 +885,9 @@ def compute_state_delta( ) } - @defer.inlineCallbacks - def unread_notifs_for_room_id(self, room_id, sync_config): + async def unread_notifs_for_room_id(self, room_id, sync_config): with Measure(self.clock, "unread_notifs_for_room_id"): - last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( + last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, receipt_type="m.read", @@ -903,7 +895,7 @@ def unread_notifs_for_room_id(self, room_id, sync_config): notifs = [] if last_unread_event_id: - notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id ) return notifs @@ -912,8 +904,9 @@ def unread_notifs_for_room_id(self, room_id, sync_config): # count is whatever it was last time. return None - @defer.inlineCallbacks - def generate_sync_result(self, sync_config, since_token=None, full_state=False): + async def generate_sync_result( + self, sync_config, since_token=None, full_state=False + ): """Generates a sync result. Args: @@ -928,7 +921,7 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): # this is due to some of the underlying streams not supporting the ability # to query up to a given point. # Always use the `now_token` in `SyncResultBuilder` - now_token = yield self.event_sources.get_current_token() + now_token = await self.event_sources.get_current_token() logger.info( "Calculating sync response for %r between %s and %s", @@ -944,10 +937,9 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - joined_room_ids = yield self.get_rooms_for_user_at( + joined_room_ids = await self.get_rooms_for_user_at( user_id, now_token.room_stream_id ) - sync_result_builder = SyncResultBuilder( sync_config, full_state, @@ -956,11 +948,11 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): joined_room_ids=joined_room_ids, ) - account_data_by_room = yield self._generate_sync_entry_for_account_data( + account_data_by_room = await self._generate_sync_entry_for_account_data( sync_result_builder ) - res = yield self._generate_sync_entry_for_rooms( + res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) newly_joined_rooms, newly_joined_or_invited_users, _, _ = res @@ -970,13 +962,13 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: - yield self._generate_sync_entry_for_presence( + await self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) - yield self._generate_sync_entry_for_to_device(sync_result_builder) + await self._generate_sync_entry_for_to_device(sync_result_builder) - device_lists = yield self._generate_sync_entry_for_device_list( + device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, newly_joined_or_invited_users=newly_joined_or_invited_users, @@ -987,11 +979,11 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): device_id = sync_config.device_id one_time_key_counts = {} if device_id: - one_time_key_counts = yield self.store.count_e2e_one_time_keys( + one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - yield self._generate_sync_entry_for_groups(sync_result_builder) + await self._generate_sync_entry_for_groups(sync_result_builder) # debug for https://github.com/matrix-org/synapse/issues/4422 for joined_room in sync_result_builder.joined: @@ -1015,18 +1007,17 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): ) @measure_func("_generate_sync_entry_for_groups") - @defer.inlineCallbacks - def _generate_sync_entry_for_groups(self, sync_result_builder): + async def _generate_sync_entry_for_groups(self, sync_result_builder): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token if since_token and since_token.groups_key: - results = yield self.store.get_groups_changes_for_user( + results = await self.store.get_groups_changes_for_user( user_id, since_token.groups_key, now_token.groups_key ) else: - results = yield self.store.get_all_groups_for_user( + results = await self.store.get_all_groups_for_user( user_id, now_token.groups_key ) @@ -1059,8 +1050,7 @@ def _generate_sync_entry_for_groups(self, sync_result_builder): ) @measure_func("_generate_sync_entry_for_device_list") - @defer.inlineCallbacks - def _generate_sync_entry_for_device_list( + async def _generate_sync_entry_for_device_list( self, sync_result_builder, newly_joined_rooms, @@ -1108,32 +1098,32 @@ def _generate_sync_entry_for_device_list( # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = yield self.store.get_users_whose_devices_changed( + users_that_have_changed = await self.store.get_users_whose_devices_changed( since_token.device_list_key, users_who_share_room ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = yield self.state.get_current_users_in_room(room_id) + joined_users = await self.state.get_current_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) - user_signatures_changed = yield self.store.get_users_whose_signatures_changed( + user_signatures_changed = await self.store.get_users_whose_signatures_changed( user_id, since_token.device_list_key ) users_that_have_changed.update(user_signatures_changed) # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = yield self.state.get_current_users_in_room(room_id) + left_users = await self.state.get_current_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1143,8 +1133,7 @@ def _generate_sync_entry_for_device_list( else: return DeviceLists(changed=[], left=[]) - @defer.inlineCallbacks - def _generate_sync_entry_for_to_device(self, sync_result_builder): + async def _generate_sync_entry_for_to_device(self, sync_result_builder): """Generates the portion of the sync response. Populates `sync_result_builder` with the result. @@ -1165,14 +1154,14 @@ def _generate_sync_entry_for_to_device(self, sync_result_builder): # We only delete messages when a new message comes in, but that's # fine so long as we delete them at some point. - deleted = yield self.store.delete_messages_for_device( + deleted = await self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) logger.debug( "Deleted %d to-device messages up to %d", deleted, since_stream_id ) - messages, stream_id = yield self.store.get_new_messages_for_device( + messages, stream_id = await self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) @@ -1190,8 +1179,7 @@ def _generate_sync_entry_for_to_device(self, sync_result_builder): else: sync_result_builder.to_device = [] - @defer.inlineCallbacks - def _generate_sync_entry_for_account_data(self, sync_result_builder): + async def _generate_sync_entry_for_account_data(self, sync_result_builder): """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. @@ -1209,25 +1197,25 @@ def _generate_sync_entry_for_account_data(self, sync_result_builder): ( account_data, account_data_by_room, - ) = yield self.store.get_updated_account_data_for_user( + ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key ) - push_rules_changed = yield self.store.have_push_rules_changed_for_user( + push_rules_changed = await self.store.have_push_rules_changed_for_user( user_id, int(since_token.push_rules_key) ) if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( account_data, account_data_by_room, - ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) + ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1242,8 +1230,7 @@ def _generate_sync_entry_for_account_data(self, sync_result_builder): return account_data_by_room - @defer.inlineCallbacks - def _generate_sync_entry_for_presence( + async def _generate_sync_entry_for_presence( self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ): """Generates the presence portion of the sync response. Populates the @@ -1271,7 +1258,7 @@ def _generate_sync_entry_for_presence( presence_key = None include_offline = False - presence, presence_key = yield presence_source.get_new_events( + presence, presence_key = await presence_source.get_new_events( user=user, from_key=presence_key, is_guest=sync_config.is_guest, @@ -1283,12 +1270,12 @@ def _generate_sync_entry_for_presence( extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states(extra_users_ids) + states = await self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user @@ -1298,8 +1285,9 @@ def _generate_sync_entry_for_presence( sync_result_builder.presence = presence - @defer.inlineCallbacks - def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + async def _generate_sync_entry_for_rooms( + self, sync_result_builder, account_data_by_room + ): """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1321,7 +1309,7 @@ def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_ro if block_all_room_ephemeral: ephemeral_by_room = {} else: - now_token, ephemeral_by_room = yield self.ephemeral_by_room( + now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, now_token=sync_result_builder.now_token, since_token=sync_result_builder.since_token, @@ -1333,16 +1321,16 @@ def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_ro since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: - have_changed = yield self._have_rooms_changed(sync_result_builder) + have_changed = await self._have_rooms_changed(sync_result_builder) if not have_changed: - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") return [], [], [], [] - ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + ignored_account_data = await self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id ) @@ -1352,18 +1340,18 @@ def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_ro ignored_users = frozenset() if since_token: - res = yield self._get_rooms_changed(sync_result_builder, ignored_users) + res = await self._get_rooms_changed(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms, newly_left_rooms = res - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - res = yield self._get_all_rooms(sync_result_builder, ignored_users) + res = await self._get_all_rooms(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res newly_left_rooms = [] - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) def handle_room_entries(room_entry): return self._generate_room_entry( @@ -1376,7 +1364,7 @@ def handle_room_entries(room_entry): always_include=sync_result_builder.full_state, ) - yield concurrently_execute(handle_room_entries, room_entries, 10) + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) @@ -1410,8 +1398,7 @@ def handle_room_entries(room_entry): newly_left_users, ) - @defer.inlineCallbacks - def _have_rooms_changed(self, sync_result_builder): + async def _have_rooms_changed(self, sync_result_builder): """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. """ @@ -1422,7 +1409,7 @@ def _have_rooms_changed(self, sync_result_builder): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1435,8 +1422,7 @@ def _have_rooms_changed(self, sync_result_builder): return True return False - @defer.inlineCallbacks - def _get_rooms_changed(self, sync_result_builder, ignored_users): + async def _get_rooms_changed(self, sync_result_builder, ignored_users): """Gets the the changes that have happened since the last sync. Args: @@ -1461,7 +1447,7 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1499,11 +1485,11 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) @@ -1536,13 +1522,13 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) if old_mem_ev and old_mem_ev.membership == Membership.JOIN: @@ -1566,7 +1552,7 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): if leave_events: leave_event = leave_events[-1] - leave_stream_token = yield self.store.get_stream_token_for_event( + leave_stream_token = await self.store.get_stream_token_for_event( leave_event.event_id ) leave_token = since_token.copy_and_replace( @@ -1603,7 +1589,7 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): timeline_limit = sync_config.filter_collection.timeline_limit() # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, @@ -1652,8 +1638,7 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): return room_entries, invited, newly_joined_rooms, newly_left_rooms - @defer.inlineCallbacks - def _get_all_rooms(self, sync_result_builder, ignored_users): + async def _get_all_rooms(self, sync_result_builder, ignored_users): """Returns entries for all rooms for the user. Args: @@ -1677,7 +1662,7 @@ def _get_all_rooms(self, sync_result_builder, ignored_users): Membership.BAN, ) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=membership_list ) @@ -1700,7 +1685,7 @@ def _get_all_rooms(self, sync_result_builder, ignored_users): elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue - invite = yield self.store.get_event(event.event_id) + invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. @@ -1726,8 +1711,7 @@ def _get_all_rooms(self, sync_result_builder, ignored_users): return room_entries, invited, [] - @defer.inlineCallbacks - def _generate_room_entry( + async def _generate_room_entry( self, sync_result_builder, ignored_users, @@ -1769,7 +1753,7 @@ def _generate_room_entry( since_token = room_builder.since_token upto_token = room_builder.upto_token - batch = yield self._load_filtered_recents( + batch = await self._load_filtered_recents( room_id, sync_config, now_token=upto_token, @@ -1796,7 +1780,7 @@ def _generate_room_entry( # tag was added by synapse e.g. for server notice rooms. if full_state: user_id = sync_result_builder.sync_config.user.to_string() - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) # If there aren't any tags, don't send the empty tags list down # sync @@ -1821,7 +1805,7 @@ def _generate_room_entry( ): return - state = yield self.compute_state_delta( + state = await self.compute_state_delta( room_id, batch, sync_config, since_token, now_token, full_state=full_state ) @@ -1844,7 +1828,7 @@ def _generate_room_entry( ) or since_token is None ): - summary = yield self.compute_summary( + summary = await self.compute_summary( room_id, sync_config, batch, state, now_token ) @@ -1861,7 +1845,7 @@ def _generate_room_entry( ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) + notifs = await self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1887,8 +1871,7 @@ def _generate_room_entry( else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) - @defer.inlineCallbacks - def get_rooms_for_user_at(self, user_id, stream_ordering): + async def get_rooms_for_user_at(self, user_id, stream_ordering): """Get set of joined rooms for a user at the given stream ordering. The stream ordering *must* be recent, otherwise this may throw an @@ -1903,7 +1886,7 @@ def get_rooms_for_user_at(self, user_id, stream_ordering): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) + joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1921,10 +1904,10 @@ def get_rooms_for_user_at(self, user_id, stream_ordering): logger.info("User joined room after current token: %s", room_id) - extrems = yield self.store.get_forward_extremeties_for_room( + extrems = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering ) - users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 856337b7e2fa..6f784543227d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -313,7 +313,7 @@ def get_new_events(self, from_key, room_ids, **kwargs): events.append(self._make_event_for(room_id)) - return events, handler._latest_room_serial + return defer.succeed((events, handler._latest_room_serial)) def get_current_key(self): return self.get_typing_handler()._latest_room_serial diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 05fc64f409f3..03934956f49d 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -256,6 +256,7 @@ def writer(r): # transport is the same, just trigger a resumeProducing. if self._producer and r.transport is self._producer.transport: self._producer.resumeProducing() + self._connection_waiter = None return # If the producer is still producing, stop it. diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2c1fb9ddacba..6747f29e6aff 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,6 +23,7 @@ See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types @@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 735b882363d5..305b9b017852 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -175,4 +175,4 @@ def run_db_interaction(self, desc, func, *args, **kwargs): Returns: Deferred[object]: result of func """ - return self._store.runInteraction(desc, func, *args, **kwargs) + return self._store.db.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/notifier.py b/synapse/notifier.py index af161a81d781..5f5f765bea5a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -304,8 +304,7 @@ def on_new_replication_data(self): without waking up any of the normal user event streams""" self.notify_replication() - @defer.inlineCallbacks - def wait_for_events( + async def wait_for_events( self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START ): """Wait until the callback returns a non empty response or the @@ -313,9 +312,9 @@ def wait_for_events( """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = yield self.event_sources.get_current_token() + current_token = await self.event_sources.get_current_token() if room_ids is None: - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, @@ -344,11 +343,11 @@ def wait_for_events( self.hs.get_reactor(), ) with PreserveLoggingContext(): - yield listener.deferred + await listener.deferred current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) if result: break @@ -364,12 +363,11 @@ def wait_for_events( # This happened if there was no timeout or if the timeout had # already expired. current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) return result - @defer.inlineCallbacks - def get_events_for( + async def get_events_for( self, user, pagination_config, @@ -391,15 +389,14 @@ def get_events_for( """ from_token = pagination_config.from_token if not from_token: - from_token = yield self.event_sources.get_current_token() + from_token = await self.event_sources.get_current_token() limit = pagination_config.limit - room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id) + room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - @defer.inlineCallbacks - def check_for_updates(before_token, after_token): + async def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) @@ -415,7 +412,7 @@ def check_for_updates(before_token, after_token): if only_keys and name not in only_keys: continue - new_events, new_key = yield source.get_new_events( + new_events, new_key = await source.get_new_events( user=user, from_key=getattr(from_token, keyname), limit=limit, @@ -425,7 +422,7 @@ def check_for_updates(before_token, after_token): ) if name == "room": - new_events = yield filter_events_for_client( + new_events = await filter_events_for_client( self.storage, user.to_string(), new_events, @@ -461,7 +458,7 @@ def check_for_updates(before_token, after_token): user_id_for_stream, ) - result = yield self.wait_for_events( + result = await self.wait_for_events( user_id_for_stream, timeout, check_for_updates, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1ba7bcd4d881..788178076032 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -386,15 +386,7 @@ def _update_rules_with_member_event_ids( """ sequence = self.sequence - rows = yield self.store._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids.values(), - retcols=("user_id", "membership", "event_id"), - keyvalues={}, - batch_size=500, - desc="_get_rules_for_member_event_ids", - ) + rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 456bc005a005..b91a52824566 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,7 +18,9 @@ import six -from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -34,8 +36,8 @@ def __func__(inp): class BaseSlavedStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id" @@ -62,7 +64,7 @@ def process_replication_rows(self, stream_name, token, rows): if stream_name == "caches": self._cache_id_gen.advance(token) for row in rows: - if row.cache_func == _CURRENT_STATE_CACHE_NAME: + if row.cache_func == CURRENT_STATE_CACHE_NAME: room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bc2f6a12ae89..ebe94909cbd3 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -18,15 +18,16 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore +from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id" ) - super(SlavedAccountDataStore, self).__init__(db_conn, hs) + super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b4f58cea19f4..fbf996e33a10 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -21,8 +22,8 @@ class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedClientIpStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 9fb6c5c6ffc1..0c237c6e0f71 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_max_stream_id", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index de50748c3013..dc625e0d7a06 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -18,12 +18,13 @@ from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d0a0eaf75bc0..29f35b9915c2 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -59,13 +60,13 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - super(SlavedEventStore, self).__init__(db_conn, hs) + super(SlavedEventStore, self).__init__(database, db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 5c84ebd12538..bcb068895496 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,13 +14,14 @@ # limitations under the License. from synapse.storage.data_stores.main.filtering import FilteringStore +from synapse.storage.database import Database from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedFilteringStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd2869..69a4ae42f933 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage import DataStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -21,8 +22,8 @@ class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 747ced0c8474..f552e7c97227 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,6 +15,7 @@ from synapse.storage import DataStore from synapse.storage.data_stores.main.presence import PresenceStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -22,8 +23,8 @@ class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPresenceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 3655f05e54f1..eebd5a1fb67b 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,17 +15,18 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.database import Database from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" ) - super(SlavedPushRuleStore, self).__init__(db_conn, hs) + super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) def get_push_rules_stream_token(self): return ( diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index b4331d0799e3..f22c2d44a327 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,14 +15,15 @@ # limitations under the License. from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPusherStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 43d823c60102..d40dc6e1f581 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -29,14 +30,14 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(db_conn, hs) + super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index d9ad386b289f..3a20f4531688 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,14 +14,15 @@ # limitations under the License. from synapse.storage.data_stores.main.room import RoomWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 68a59a34249a..c122c449f401 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -34,12 +34,12 @@ from synapse.rest.admin.users import ( AccountValidityRenewServlet, DeactivateAccountRestServlet, - GetUsersPaginatedRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, UserAdminServlet, UserRegisterServlet, UsersRestServlet, + UsersRestServletV2, WhoisRestServlet, ) from synapse.util.versionstring import get_version_string @@ -191,6 +191,7 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UsersRestServletV2(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): @@ -201,7 +202,6 @@ def register_servlets_for_client_rest_resource(hs, http_server): PurgeHistoryRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server) - GetUsersPaginatedRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 58a83f93af05..1937879dbe17 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -25,6 +25,7 @@ from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_integer, parse_json_object_from_request, parse_string, @@ -59,71 +60,45 @@ async def on_GET(self, request, user_id): return 200, ret -class GetUsersPaginatedRestServlet(RestServlet): - """Get request to get specific number of users from Synapse. +class UsersRestServletV2(RestServlet): + PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),) + + """Get request to list all local users. This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/users_paginate/ - @admin:user?access_token=admin_access_token&start=0&limit=10 - Returns: - 200 OK with json object {list[dict[str, Any]], count} or empty object. - """ - PATTERNS = historical_admin_path_patterns( - "/users_paginate/(?P[^/]*)" - ) + GET /_synapse/admin/v2/users?from=0&limit=10&guests=false + + returns: + 200 OK with list of users if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `user_id` can be used to filter by user id. + The parameter `guests` can be used to exclude guest users. + The parameter `deactivated` can be used to include deactivated users. + """ def __init__(self, hs): - self.store = hs.get_datastore() self.hs = hs self.auth = hs.get_auth() - self.handlers = hs.get_handlers() + self.admin_handler = hs.get_handlers().admin_handler - async def on_GET(self, request, target_user_id): - """Get request to get specific number of users from Synapse. - This needs user to have administrator access in Synapse. - """ + async def on_GET(self, request): await assert_requester_is_admin(self.auth, request) - target_user = UserID.from_string(target_user_id) - - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") - - order = "name" # order by name in user table - start = parse_integer(request, "start", required=True) - limit = parse_integer(request, "limit", required=True) - - logger.info("limit: %s, start: %s", limit, start) - - ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) - return 200, ret + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + user_id = parse_string(request, "user_id", default=None) + guests = parse_boolean(request, "guests", default=True) + deactivated = parse_boolean(request, "deactivated", default=False) - async def on_POST(self, request, target_user_id): - """Post request to get specific number of users from Synapse.. - This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/users_paginate/ - @admin:user?access_token=admin_access_token - JsonBodyToSend: - { - "start": "0", - "limit": "10 - } - Returns: - 200 OK with json object {list[dict[str, Any]], count} or empty object. - """ - await assert_requester_is_admin(self.auth, request) - UserID.from_string(target_user_id) - - order = "name" # order by name in user table - params = parse_json_object_from_request(request) - assert_params_in_dict(params, ["limit", "start"]) - limit = params["limit"] - start = params["start"] - logger.info("limit: %s, start: %s", limit, start) + users = await self.admin_handler.get_users_paginate( + start, limit, user_id, guests, deactivated + ) + ret = {"users": users} + if len(users) >= limit: + ret["next_token"] = str(start + len(users)) - ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) return 200, ret diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4ea3666874ee..5934b1fe8bdc 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( AuthError, Codes, @@ -47,17 +45,15 @@ def __init__(self, hs): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_alias): + async def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) dir_handler = self.handlers.directory_handler - res = yield dir_handler.get_association(room_alias) + res = await dir_handler.get_association(room_alias) return 200, res - @defer.inlineCallbacks - def on_PUT(self, request, room_alias): + async def on_PUT(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) content = parse_json_object_from_request(request) @@ -77,26 +73,25 @@ def on_PUT(self, request, room_alias): # TODO(erikj): Check types. - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Room does not exist") - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.create_association( + await self.handlers.directory_handler.create_association( requester, room_alias, room_id, servers ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_alias): + async def on_DELETE(self, request, room_alias): dir_handler = self.handlers.directory_handler try: - service = yield self.auth.get_appservice_by_req(request) + service = await self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association(service, room_alias) + await dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, @@ -107,12 +102,12 @@ def on_DELETE(self, request, room_alias): # fallback to default user behaviour if they aren't an AS pass - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user = requester.user room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association(requester, room_alias) + await dir_handler.delete_association(requester, room_alias) logger.info( "User %s deleted alias %s", user.to_string(), room_alias.to_string() @@ -130,32 +125,29 @@ def __init__(self, hs): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - room = yield self.store.get_room(room_id) + async def on_GET(self, request, room_id): + room = await self.store.get_room(room_id) if room is None: raise NotFoundError("Unknown room") return 200, {"visibility": "public" if room["is_public"] else "private"} - @defer.inlineCallbacks - def on_PUT(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, visibility ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, "private" ) @@ -181,15 +173,14 @@ def on_PUT(self, request, network_id, room_id): def on_DELETE(self, request, network_id, room_id): return self._edit(request, network_id, room_id, "private") - @defer.inlineCallbacks - def _edit(self, request, network_id, room_id, visibility): - requester = yield self.auth.get_user_by_req(request) + async def _edit(self, request, network_id, room_id, visibility): + requester = await self.auth.get_user_by_req(request) if not requester.app_service: raise AuthError( 403, "Only appservices can edit the appservice published room list" ) - yield self.handlers.directory_handler.edit_published_appservice_room_list( + await self.handlers.directory_handler.edit_published_appservice_room_list( requester.app_service.id, network_id, room_id, visibility ) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 6651b4cf0751..4beb6177333d 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -16,8 +16,6 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -36,9 +34,8 @@ def __init__(self, hs): self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: @@ -57,7 +54,7 @@ def on_GET(self, request): as_client_event = b"raw" not in request.args - chunk = yield self.event_stream_handler.get_stream( + chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -83,14 +80,13 @@ def __init__(self, hs): self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request, event_id): - requester = yield self.auth.get_user_by_req(request) - event = yield self.event_handler.get_event(requester.user, None, event_id) + async def on_GET(self, request, event_id): + requester = await self.auth.get_user_by_req(request) + event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event else: return 404, "Event not found." diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 2da3cd75115c..910b3b4eeb94 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_boolean from synapse.rest.client.v2_alpha._base import client_patterns @@ -29,13 +28,12 @@ def __init__(self, hs): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) - content = yield self.initial_sync_handler.snapshot_all_rooms( + content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, as_client_event=as_client_event, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 19eb15003d3d..ff9c978fe737 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,7 +18,6 @@ from six.moves import urllib -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError @@ -130,8 +129,7 @@ def on_GET(self, request): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): self._address_ratelimiter.ratelimit( request.getClientIP(), time_now_s=self.hs.clock.time(), @@ -145,11 +143,11 @@ def on_POST(self, request): if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE ): - result = yield self.do_jwt_login(login_submission) + result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - result = yield self.do_token_login(login_submission) + result = await self.do_token_login(login_submission) else: - result = yield self._do_other_login(login_submission) + result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -158,8 +156,7 @@ def on_POST(self, request): result["well_known"] = well_known_data return 200, result - @defer.inlineCallbacks - def _do_other_login(self, login_submission): + async def _do_other_login(self, login_submission): """Handle non-token/saml/jwt logins Args: @@ -219,20 +216,20 @@ def _do_other_login(self, login_submission): ( canonical_user_id, callback_3pid, - ) = yield self.auth_handler.check_password_provider_3pid( + ) = await self.auth_handler.check_password_provider_3pid( medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback_3pid ) return result # No password providers were able to handle this 3pid # Check local store - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastore().get_user_id_by_threepid( medium, address ) if not user_id: @@ -280,7 +277,7 @@ def _do_other_login(self, login_submission): ) try: - canonical_user_id, callback = yield self.auth_handler.validate_login( + canonical_user_id, callback = await self.auth_handler.validate_login( identifier["user"], login_submission ) except LoginError: @@ -297,13 +294,12 @@ def _do_other_login(self, login_submission): ) raise - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback ) return result - @defer.inlineCallbacks - def _complete_login( + async def _complete_login( self, user_id, login_submission, callback=None, create_non_existant_users=False ): """Called when we've successfully authed the user and now need to @@ -337,15 +333,15 @@ def _complete_login( ) if create_non_existant_users: - user_id = yield self.auth_handler.check_user_exists(user_id) + user_id = await self.auth_handler.check_user_exists(user_id) if not user_id: - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( localpart=UserID.from_string(user_id).localpart ) device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name ) @@ -357,23 +353,21 @@ def _complete_login( } if callback is not None: - yield callback(result) + await callback(result) return result - @defer.inlineCallbacks - def do_token_login(self, login_submission): + async def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = yield self._complete_login(user_id, login_submission) + result = await self._complete_login(user_id, login_submission) return result - @defer.inlineCallbacks - def do_jwt_login(self, login_submission): + async def do_jwt_login(self, login_submission): token = login_submission.get("token", None) if token is None: raise LoginError( @@ -397,7 +391,7 @@ def do_jwt_login(self, login_submission): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - result = yield self._complete_login( + result = await self._complete_login( user_id, login_submission, create_non_existant_users=True ) return result @@ -460,8 +454,7 @@ def __init__(self, hs): self._sso_auth_handler = SSOAuthHandler(hs) self._http_client = hs.get_proxied_http_client() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): client_redirect_url = parse_string(request, "redirectUrl", required=True) uri = self.cas_server_url + "/proxyValidate" args = { @@ -469,12 +462,12 @@ def on_GET(self, request): "service": self.cas_service_url, } try: - body = yield self._http_client.get_raw(uri, args) + body = await self._http_client.get_raw(uri, args) except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data body = pde.response - result = yield self.handle_cas_response(request, body, client_redirect_url) + result = await self.handle_cas_response(request, body, client_redirect_url) return result def handle_cas_response(self, request, cas_response_body, client_redirect_url): @@ -555,8 +548,7 @@ def __init__(self, hs): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() - @defer.inlineCallbacks - def on_successful_auth( + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. @@ -582,9 +574,9 @@ def on_successful_auth( """ localpart = map_username_to_mxid_localpart(username) user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = yield self._auth_handler.check_user_exists(user_id) + registered_user_id = await self._auth_handler.check_user_exists(user_id) if not registered_user_id: - registered_user_id = yield self._registration_handler.register_user( + registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=user_display_name ) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 4785a34d75ab..1cf3caf832f2 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -35,17 +33,16 @@ def __init__(self, hs): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) if requester.device_id is None: # the acccess token wasn't associated with a device. # Just delete the access token access_token = self.auth.get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) + await self._auth_handler.delete_access_token(access_token) else: - yield self._device_handler.delete_device( + await self._device_handler.delete_device( requester.user.to_string(), requester.device_id ) @@ -64,17 +61,16 @@ def __init__(self, hs): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # first delete all of the user's devices - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # .. and then delete any access tokens which weren't associated with # devices. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) return 200, {} diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 0153525cefea..eec16f8ad829 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,8 +19,6 @@ from six import string_types -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -40,27 +38,25 @@ def __init__(self, hs): self.clock = hs.get_clock() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.presence_handler.is_visible( + allowed = await self.presence_handler.is_visible( observed_user=user, observer_user=requester.user ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.presence_handler.get_state(target_user=user) + state = await self.presence_handler.get_state(target_user=user) state = format_user_presence_state(state, self.clock.time_msec()) return 200, state - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: @@ -86,7 +82,7 @@ def on_PUT(self, request, user_id): raise SynapseError(400, "Unable to parse state") if self.hs.config.use_presence: - yield self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index bbce2e2b71d9..4f47562c1be0 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -14,7 +14,6 @@ # limitations under the License. """ This module contains REST servlets to do with profile: /profile/ """ -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns @@ -30,19 +29,18 @@ def __init__(self, hs): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) + displayname = await self.profile_handler.get_displayname(user) ret = {} if displayname is not None: @@ -50,11 +48,10 @@ def on_GET(self, request, user_id): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ def on_PUT(self, request, user_id): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) + await self.profile_handler.set_displayname(user, requester, new_name, is_admin) return 200, {} @@ -80,19 +77,18 @@ def __init__(self, hs): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if avatar_url is not None: @@ -100,19 +96,23 @@ def on_GET(self, request, user_id): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) try: - new_name = content["avatar_url"] + new_avatar_url = content.get("avatar_url") except Exception: - return 400, "Unable to parse name" + return 400, "Unable to parse avatar_url" + + if new_avatar_url is None: + return 400, "Missing required key: avatar_url" - yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) + await self.profile_handler.set_avatar_url( + user, requester, new_avatar_url, is_admin + ) return 200, {} @@ -129,20 +129,19 @@ def __init__(self, hs): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + displayname = await self.profile_handler.get_displayname(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if displayname is not None: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 9f8c3d09e3a5..4f746002393c 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import ( NotFoundError, @@ -46,8 +45,7 @@ def __init__(self, hs): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None - @defer.inlineCallbacks - def on_PUT(self, request, path): + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,7 +55,7 @@ def on_PUT(self, request, path): except InvalidRuleException as e: raise SynapseError(400, str(e)) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") @@ -67,7 +65,7 @@ def on_PUT(self, request, path): user_id = requester.user.to_string() if "attr" in spec: - yield self.set_rule_attr(user_id, spec, content) + await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} @@ -91,7 +89,7 @@ def on_PUT(self, request, path): after = _namespaced_rule_id(spec, after) try: - yield self.store.add_push_rule( + await self.store.add_push_rule( user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, @@ -108,20 +106,19 @@ def on_PUT(self, request, path): return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, path): + async def on_DELETE(self, request, path): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") spec = _rule_spec_from_path([x for x in path.split("/")]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule(user_id, namespaced_rule_id) + await self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) return 200, {} except StoreError as e: @@ -130,15 +127,14 @@ def on_DELETE(self, request, path): else: raise - @defer.inlineCallbacks - def on_GET(self, request, path): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, path): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(requester.user, rules) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 41660682d97f..0791866f5532 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -39,12 +37,11 @@ def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -78,9 +75,8 @@ def __init__(self, hs): self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user content = parse_json_object_from_request(request) @@ -91,7 +87,7 @@ def on_POST(self, request): and "kind" in content and content["kind"] is None ): - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( content["app_id"], content["pushkey"], user_id=user.to_string() ) return 200, {} @@ -117,14 +113,14 @@ def on_POST(self, request): append = content["append"] if not append: - yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], pushkey=content["pushkey"], not_user_id=user.to_string(), ) try: - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -164,16 +160,15 @@ def __init__(self, hs): self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) try: - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 2afdbb89e56c..747d46eac201 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -17,8 +17,6 @@ import hashlib import hmac -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -31,9 +29,8 @@ def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req( request, self.hs.config.turn_allow_guests ) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 8250ae0ae116..2a3f4dd58f28 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -78,7 +78,7 @@ def on_POST(self, request): """ def wrapped(*args, **kwargs): - res = defer.maybeDeferred(orig, *args, **kwargs) + res = defer.ensureDeferred(orig(*args, **kwargs)) res.addErrback(_catch_incomplete_interactive_auth) return res diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ad674239ab4f..fc240f5cf818 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,8 +18,6 @@ from six.moves import http_client -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour @@ -67,8 +65,7 @@ def __init__(self, hs): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -95,7 +92,7 @@ def on_POST(self, request): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -106,7 +103,7 @@ def on_POST(self, request): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -115,7 +112,7 @@ def on_POST(self, request): ) else: # Send password reset emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -153,8 +150,7 @@ def __init__(self, hs): [self.config.email_password_reset_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): # We currently only handle threepid token submissions for email if medium != "email": raise SynapseError( @@ -176,7 +172,7 @@ def on_GET(self, request, medium): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -218,8 +214,7 @@ def __init__(self, hs): self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) # there are two possibilities here. Either the user does not have an @@ -233,14 +228,14 @@ def on_POST(self, request): # In the second case, we require a password to confirm their identity. if self.auth.has_access_token(request): - requester = yield self.auth.get_user_by_req(request) - params = yield self.auth_handler.validate_user_via_ui_auth( + requester = await self.auth.get_user_by_req(request) + params = await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) user_id = requester.user.to_string() else: requester = None - result, params, _ = yield self.auth_handler.check_auth( + result, params, _ = await self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) ) @@ -254,7 +249,7 @@ def on_POST(self, request): # (See add_threepid in synapse/handlers/auth.py) threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] ) if not threepid_user_id: @@ -267,7 +262,7 @@ def on_POST(self, request): assert_params_in_dict(params, ["new_password"]) new_password = params["new_password"] - yield self._set_password_handler.set_password(user_id, new_password, requester) + await self._set_password_handler.set_password(user_id, new_password, requester) return 200, {} @@ -286,8 +281,7 @@ def __init__(self, hs): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -297,19 +291,19 @@ def on_POST(self, request): Codes.BAD_JSON, ) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) # allow ASes to dectivate their own users if requester.app_service: - yield self._deactivate_account_handler.deactivate_account( + await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase ) return 200, {} - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - result = yield self._deactivate_account_handler.deactivate_account( + result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: @@ -346,8 +340,7 @@ def __init__(self, hs): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -371,7 +364,7 @@ def on_POST(self, request): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( "email", body["email"] ) @@ -382,7 +375,7 @@ def on_POST(self, request): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -391,7 +384,7 @@ def on_POST(self, request): ) else: # Send threepid validation emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -414,8 +407,7 @@ def __init__(self, hs): self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( body, ["client_secret", "country", "phone_number", "send_attempt"] @@ -435,7 +427,7 @@ def on_POST(self, request): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -450,7 +442,7 @@ def on_POST(self, request): "Adding phone numbers to user account is not supported by this homeserver", ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -484,8 +476,7 @@ def __init__(self, hs): [self.config.email_add_threepid_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -508,7 +499,7 @@ def on_GET(self, request): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -558,8 +549,7 @@ def __init__(self, hs): self.store = hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.config.account_threepid_delegate_msisdn: raise SynapseError( 400, @@ -571,7 +561,7 @@ def on_POST(self, request): assert_params_in_dict(body, ["client_secret", "sid", "token"]) # Proxy submit_token request to msisdn threepid delegate - response = yield self.identity_handler.proxy_msisdn_submit_token( + response = await self.identity_handler.proxy_msisdn_submit_token( self.config.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], @@ -591,17 +581,15 @@ def __init__(self, hs): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) + threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -615,11 +603,11 @@ def on_POST(self, request): client_secret = threepid_creds["client_secret"] sid = threepid_creds["sid"] - validation_session = yield self.identity_handler.validate_threepid_session( + validation_session = await self.identity_handler.validate_threepid_session( client_secret, sid ) if validation_session: - yield self.auth_handler.add_threepid( + await self.auth_handler.add_threepid( user_id, validation_session["medium"], validation_session["address"], @@ -643,9 +631,8 @@ def __init__(self, hs): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -653,15 +640,15 @@ def on_POST(self, request): client_secret = body["client_secret"] sid = body["sid"] - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - validation_session = yield self.identity_handler.validate_threepid_session( + validation_session = await self.identity_handler.validate_threepid_session( client_secret, sid ) if validation_session: - yield self.auth_handler.add_threepid( + await self.auth_handler.add_threepid( user_id, validation_session["medium"], validation_session["address"], @@ -683,8 +670,7 @@ def __init__(self, hs): self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) @@ -693,10 +679,10 @@ def on_POST(self, request): client_secret = body["client_secret"] id_access_token = body.get("id_access_token") # optional - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.identity_handler.bind_threepid( + await self.identity_handler.bind_threepid( client_secret, sid, user_id, id_server, id_access_token ) @@ -713,12 +699,11 @@ def __init__(self, hs): self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) @@ -728,7 +713,7 @@ def on_POST(self, request): # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past - result = yield self.identity_handler.try_unbind_threepid( + result = await self.identity_handler.try_unbind_threepid( requester.user.to_string(), {"address": address, "medium": medium, "id_server": id_server}, ) @@ -743,16 +728,15 @@ def __init__(self, hs): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: - ret = yield self.auth_handler.delete_threepid( + ret = await self.auth_handler.delete_threepid( user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: @@ -777,9 +761,8 @@ def __init__(self, hs): super(WhoamiRestServlet, self).__init__() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) return 200, {"user_id": requester.user.to_string()} diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f0db204ffa20..64eb7fec3bfb 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -41,15 +39,14 @@ def __init__(self, hs): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_account_data_for_user( + max_id = await self.store.add_account_data_for_user( user_id, account_data_type, body ) @@ -57,13 +54,12 @@ def on_PUT(self, request, user_id, account_data_type): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_global_account_data_by_type_for_user( + event = await self.store.get_global_account_data_by_type_for_user( account_data_type, user_id ) @@ -91,9 +87,8 @@ def __init__(self, hs): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -106,7 +101,7 @@ def on_PUT(self, request, user_id, room_id, account_data_type): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -114,13 +109,12 @@ def on_PUT(self, request, user_id, room_id, account_data_type): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_account_data_for_room_and_type( + event = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type ) diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 33f6a23028d2..2f10fa64e2e0 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet @@ -45,13 +43,12 @@ def __init__(self, hs): self.success_html = hs.config.account_validity.account_renewed_html_content self.failure_html = hs.config.account_validity.invalid_token_html_content - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = yield self.account_activity_handler.renew_account( + token_valid = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) @@ -67,7 +64,6 @@ def on_GET(self, request): request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.write(response.encode("utf8")) finish_request(request) - defer.returnValue(None) class AccountValiditySendMailServlet(RestServlet): @@ -85,18 +81,17 @@ def __init__(self, hs): self.auth = hs.get_auth() self.account_validity = self.hs.config.account_validity - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) - requester = yield self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() - yield self.account_activity_handler.send_renewal_email_to_user(user_id) + await self.account_activity_handler.send_renewal_email_to_user(user_id) - defer.returnValue((200, {})) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index f21aff39e596..7a256b6ecb13 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX @@ -171,8 +169,7 @@ def on_GET(self, request, stagetype): else: raise SynapseError(404, "Unknown auth stage type") - @defer.inlineCallbacks - def on_POST(self, request, stagetype): + async def on_POST(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -186,7 +183,7 @@ def on_POST(self, request, stagetype): authdict = {"response": response, "session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) @@ -215,7 +212,7 @@ def on_POST(self, request, stagetype): session = request.args["session"][0] authdict = {"session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index acd58af193a6..fe9d019c442b 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet @@ -40,10 +38,9 @@ def __init__(self, hs): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - user = yield self.store.get_user_by_id(requester.user.to_string()) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = await self.store.get_user_by_id(requester.user.to_string()) change_password = bool(user["password_hash"]) response = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 26d02352087b..94ff73f384e1 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api import errors from synapse.http.servlet import ( RestServlet, @@ -42,10 +40,9 @@ def __init__(self, hs): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - devices = yield self.device_handler.get_devices_by_user( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + devices = await self.device_handler.get_devices_by_user( requester.user.to_string() ) return 200, {"devices": devices} @@ -67,9 +64,8 @@ def __init__(self, hs): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -84,11 +80,11 @@ def on_POST(self, request): assert_params_in_dict(body, ["devices"]) - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_devices( + await self.device_handler.delete_devices( requester.user.to_string(), body["devices"] ) return 200, {} @@ -108,18 +104,16 @@ def __init__(self, hs): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - device = yield self.device_handler.get_device( + async def on_GET(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + device = await self.device_handler.get_device( requester.user.to_string(), device_id ) return 200, device @interactive_auth_handler - @defer.inlineCallbacks - def on_DELETE(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, device_id): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -132,19 +126,18 @@ def on_DELETE(self, request, device_id): else: raise - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_device(requester.user.to_string(), device_id) + await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) - yield self.device_handler.update_device( + await self.device_handler.update_device( requester.user.to_string(), device_id, body ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 17a8bc7366fd..b28da017cd51 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -35,10 +33,9 @@ def __init__(self, hs): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_GET(self, request, user_id, filter_id): + async def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot get filters for other users") @@ -52,7 +49,7 @@ def on_GET(self, request, user_id, filter_id): raise SynapseError(400, "Invalid filter_id") try: - filter_collection = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user_localpart=target_user.localpart, filter_id=filter_id ) except StoreError as e: @@ -72,11 +69,10 @@ def __init__(self, hs): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_POST(self, request, user_id): + async def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot create filters for other users") @@ -87,7 +83,7 @@ def on_POST(self, request, user_id): content = parse_json_object_from_request(request) set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) - filter_id = yield self.filtering.add_user_filter( + filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content ) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 999a0fa80c9c..d84a6d7e1108 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID @@ -38,24 +36,22 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - group_description = yield self.groups_handler.get_group_profile( + group_description = await self.groups_handler.get_group_profile( group_id, requester_user_id ) return 200, group_description - @defer.inlineCallbacks - def on_POST(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - yield self.groups_handler.update_group_profile( + await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -74,12 +70,11 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - get_group_summary = yield self.groups_handler.get_group_summary( + get_group_summary = await self.groups_handler.get_group_summary( group_id, requester_user_id ) @@ -106,13 +101,12 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_room( + resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, room_id=room_id, @@ -122,12 +116,11 @@ def on_PUT(self, request, group_id, category_id, room_id): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_room( + resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -148,35 +141,32 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_category( + category = await self.groups_handler.get_group_category( group_id, requester_user_id, category_id=category_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_category( + resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_category( + resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -195,12 +185,11 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_categories( + category = await self.groups_handler.get_group_categories( group_id, requester_user_id ) @@ -219,35 +208,32 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_role( + category = await self.groups_handler.get_group_role( group_id, requester_user_id, role_id=role_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_role( + resp = await self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_role( + resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -266,12 +252,11 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_roles( + category = await self.groups_handler.get_group_roles( group_id, requester_user_id ) @@ -298,13 +283,12 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_user( + resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, user_id=user_id, @@ -314,12 +298,11 @@ def on_PUT(self, request, group_id, role_id, user_id): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_user( + resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -338,12 +321,11 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group( + result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -362,12 +344,11 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group( + result = await self.groups_handler.get_users_in_group( group_id, requester_user_id ) @@ -386,12 +367,11 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_invited_users_in_group( + result = await self.groups_handler.get_invited_users_in_group( group_id, requester_user_id ) @@ -409,14 +389,13 @@ def __init__(self, hs): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.set_group_join_policy( + result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -436,9 +415,8 @@ def __init__(self, hs): self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() # TODO: Create group on remote server @@ -446,7 +424,7 @@ def on_POST(self, request): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() - result = yield self.groups_handler.create_group( + result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -467,24 +445,22 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.add_room_to_group( + result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) return 200, result - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.remove_room_from_group( + result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -506,13 +482,12 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id, config_key): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id, config_key): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.update_room_in_group( + result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -535,14 +510,13 @@ def __init__(self, hs): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) - result = yield self.groups_handler.invite( + result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -563,13 +537,12 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -588,13 +561,12 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -613,13 +585,12 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.join_group( + result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -638,13 +609,12 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.accept_invite( + result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -663,14 +633,13 @@ def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity(group_id, requester_user_id, publicise) + await self.store.update_group_publicity(group_id, requester_user_id, publicise) return 200, {} @@ -688,11 +657,10 @@ def __init__(self, hs): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, user_id): + await self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user(user_id) + result = await self.groups_handler.get_publicised_groups_for_user(user_id) return 200, result @@ -710,14 +678,13 @@ def __init__(self, hs): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) + result = await self.groups_handler.bulk_get_publicised_groups(user_ids) return 200, result @@ -734,12 +701,11 @@ def __init__(self, hs): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_joined_groups(requester_user_id) + result = await self.groups_handler.get_joined_groups(requester_user_id) return 200, result diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 341567ae2139..f7ed4daf90a7 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, @@ -71,9 +69,8 @@ def __init__(self, hs): self.e2e_keys_handler = hs.get_e2e_keys_handler() @trace(opname="upload_keys") - @defer.inlineCallbacks - def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -103,7 +100,7 @@ def on_POST(self, request, device_id): 400, "To upload keys, you must pass device_id when authenticating" ) - result = yield self.e2e_keys_handler.upload_keys_for_user( + result = await self.e2e_keys_handler.upload_keys_for_user( user_id, device_id, body ) return 200, result @@ -154,13 +151,12 @@ def __init__(self, hs): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id) + result = await self.e2e_keys_handler.query_devices(body, timeout, user_id) return 200, result @@ -185,9 +181,8 @@ def __init__(self, hs): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from") set_tag("from", from_token_string) @@ -200,7 +195,7 @@ def on_GET(self, request): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed(user_id, from_token) + results = await self.device_handler.get_user_ids_changed(user_id, from_token) return 200, results @@ -231,12 +226,11 @@ def __init__(self, hs): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) + result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) return 200, result @@ -263,17 +257,16 @@ def __init__(self, hs): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) return 200, result @@ -315,13 +308,12 @@ def __init__(self, hs): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( + result = await self.e2e_keys_handler.upload_signatures_for_device_keys( user_id, body ) return 200, result diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 10c1ad5b07b8..aa911d75ee89 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -35,9 +33,8 @@ def __init__(self, hs): self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() from_token = parse_string(request, "from", required=False) @@ -46,16 +43,16 @@ def on_GET(self, request): limit = min(limit, 500) - push_actions = yield self.store.get_push_actions_for_user( + push_actions = await self.store.get_push_actions_for_user( user_id, from_token, limit, only_highlight=(only == "highlight") ) - receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + receipts_by_room = await self.store.get_receipts_for_user_with_orderings( user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] - notif_events = yield self.store.get_events(notif_event_ids) + notif_events = await self.store.get_events(notif_event_ids) returned_push_actions = [] @@ -68,7 +65,7 @@ def on_GET(self, request): "actions": pa["actions"], "ts": pa["received_ts"], "event": ( - yield self._event_serializer.serialize_event( + await self._event_serializer.serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index b4925c0f5933..6ae9a5a8e9fe 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string @@ -68,9 +66,8 @@ def __init__(self, hs): self.clock = hs.get_clock() self.server_name = hs.config.server_name - @defer.inlineCallbacks - def on_POST(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -81,7 +78,7 @@ def on_POST(self, request, user_id): token = random_string(24) ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS - yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) return ( 200, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 91db92381402..66de16a1fa8d 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -20,8 +20,6 @@ from six import string_types -from twisted.internet import defer - import synapse import synapse.types from synapse.api.constants import LoginType @@ -102,8 +100,7 @@ def __init__(self, hs): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -129,7 +126,7 @@ def on_POST(self, request): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", body["email"] ) @@ -140,7 +137,7 @@ def on_POST(self, request): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -149,7 +146,7 @@ def on_POST(self, request): ) else: # Send registration emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -175,8 +172,7 @@ def __init__(self, hs): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( @@ -197,7 +193,7 @@ def on_POST(self, request): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) @@ -215,7 +211,7 @@ def on_POST(self, request): 400, "Registration by phone number is not supported on this homeserver" ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -258,8 +254,7 @@ def __init__(self, hs): [self.config.email_registration_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" @@ -280,7 +275,7 @@ def on_GET(self, request, medium): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -338,8 +333,7 @@ def __init__(self, hs): ), ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if not self.hs.config.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN @@ -347,11 +341,11 @@ def on_GET(self, request): ip = self.hs.get_ip_from_request(request) with self.ratelimiter.ratelimit(ip) as wait_deferred: - yield wait_deferred + await wait_deferred username = parse_string(request, "username", required=True) - yield self.registration_handler.check_username(username) + await self.registration_handler.check_username(username) return 200, {"available": True} @@ -382,8 +376,7 @@ def __init__(self, hs): ) @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) client_addr = request.getClientIP() @@ -408,7 +401,7 @@ def on_POST(self, request): kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body, address=client_addr) + ret = await self._do_guest_registration(body, address=client_addr) return ret elif kind != b"user": raise UnrecognizedRequestError( @@ -435,7 +428,7 @@ def on_POST(self, request): appservice = None if self.auth.has_access_token(request): - appservice = yield self.auth.get_appservice_by_req(request) + appservice = await self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -455,7 +448,7 @@ def on_POST(self, request): access_token = self.auth.get_access_token_from_request(request) if isinstance(desired_username, string_types): - result = yield self._do_appservice_registration( + result = await self._do_appservice_registration( desired_username, access_token, body ) return 200, result # we throw for non 200 responses @@ -495,13 +488,13 @@ def on_POST(self, request): ) if desired_username is not None: - yield self.registration_handler.check_username( + await self.registration_handler.check_username( desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, ) - auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = await self.auth_handler.check_auth( self._registration_flows, body, self.hs.get_ip_from_request(request) ) @@ -557,7 +550,7 @@ def on_POST(self, request): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( medium, address ) @@ -568,7 +561,7 @@ def on_POST(self, request): Codes.THREEPID_IN_USE, ) - registered_user_id = yield self.registration_handler.register_user( + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, @@ -581,7 +574,7 @@ def on_POST(self, request): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(registered_user_id) + await self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) @@ -591,12 +584,12 @@ def on_POST(self, request): registered = True - return_dict = yield self._create_registration_details( + return_dict = await self._create_registration_details( registered_user_id, params ) if registered: - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, access_token=return_dict.get("access_token"), @@ -607,15 +600,13 @@ def on_POST(self, request): def on_OPTIONS(self, _): return 200, {} - @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token, body): - user_id = yield self.registration_handler.appservice_register( + async def _do_appservice_registration(self, username, as_token, body): + user_id = await self.registration_handler.appservice_register( username, as_token ) - return (yield self._create_registration_details(user_id, body)) + return await self._create_registration_details(user_id, body) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, params): + async def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -631,18 +622,17 @@ def _create_registration_details(self, user_id, params): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False ) result.update({"access_token": access_token, "device_id": device_id}) return result - @defer.inlineCallbacks - def _do_guest_registration(self, params, address=None): + async def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( make_guest=True, address=address ) @@ -650,7 +640,7 @@ def _do_guest_registration(self, params, address=None): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 040b37c50436..9be9a34b9141 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -21,8 +21,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -86,11 +84,10 @@ def on_PUT(self, request, *args, **kwargs): request, self.on_PUT_or_POST, request, *args, **kwargs ) - @defer.inlineCallbacks - def on_PUT_or_POST( + async def on_PUT_or_POST( self, request, room_id, parent_id, relation_type, event_type, txn_id=None ): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: # Add relations to a membership is meaningless, so we just deny it @@ -114,7 +111,7 @@ def on_PUT_or_POST( "sender": requester.user.to_string(), } - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) @@ -140,17 +137,18 @@ def __init__(self, hs): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") @@ -167,7 +165,7 @@ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=Non if to_token: to_token = RelationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_relations_for_event( + pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -176,7 +174,7 @@ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=Non to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] ) @@ -184,13 +182,13 @@ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=Non # We set bundle_aggregations to False when retrieving the original # event because we want the content before relations were applied to # it. - original_event = yield self._event_serializer.serialize_event( + original_event = await self._event_serializer.serialize_event( event, now, bundle_aggregations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( events, now, bundle_aggregations=False ) @@ -232,17 +230,18 @@ def __init__(self, hs): self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This checks that a) the event exists and b) the user is allowed to # view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -262,7 +261,7 @@ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=Non if to_token: to_token = AggregationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_aggregation_groups_for_event( + pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, event_type=event_type, limit=limit, @@ -311,17 +310,16 @@ def __init__(self, hs): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This checks that a) the event exists and b) the user is allowed to # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -336,7 +334,7 @@ def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): if to_token: to_token = RelationPaginationToken.from_string(to_token) - result = yield self.store.get_relations_for_event( + result = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -346,12 +344,12 @@ def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in result.chunk] ) now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events(events, now) + events = await self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = events diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e7449864cddf..f067b5edac56 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -18,8 +18,6 @@ from six import string_types from six.moves import http_client -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, @@ -42,9 +40,8 @@ def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ def on_POST(self, request, room_id, event_id): Codes.BAD_JSON, ) - yield self.store.add_event_report( + await self.store.add_event_report( room_id=room_id, event_id=event_id, user_id=user_id, diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index d83ac8e3c57a..38952a1d276a 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -43,8 +41,7 @@ def __init__(self, hs): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, session_id): + async def on_PUT(self, request, room_id, session_id): """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -123,7 +120,7 @@ def on_PUT(self, request, room_id, session_id): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) version = parse_string(request, "version") @@ -134,11 +131,10 @@ def on_PUT(self, request, room_id, session_id): if room_id: body = {"rooms": {room_id: body}} - ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) return 200, ret - @defer.inlineCallbacks - def on_GET(self, request, room_id, session_id): + async def on_GET(self, request, room_id, session_id): """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -190,11 +186,11 @@ def on_GET(self, request, room_id, session_id): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - room_keys = yield self.e2e_room_keys_handler.get_room_keys( + room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id ) @@ -220,8 +216,7 @@ def on_GET(self, request, room_id, session_id): return 200, room_keys - @defer.inlineCallbacks - def on_DELETE(self, request, room_id, session_id): + async def on_DELETE(self, request, room_id, session_id): """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -235,11 +230,11 @@ def on_DELETE(self, request, room_id, session_id): the version must already have been created via the /change_secret API. """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - ret = yield self.e2e_room_keys_handler.delete_room_keys( + ret = await self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) return 200, ret @@ -257,8 +252,7 @@ def __init__(self, hs): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """ Create a new backup version for this user's room_keys with the given info. The version is allocated by the server and returned to the user @@ -288,11 +282,11 @@ def on_POST(self, request): "version": 12345 } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) + new_version = await self.e2e_room_keys_handler.create_version(user_id, info) return 200, {"version": new_version} # we deliberately don't have a PUT /version, as these things really should @@ -311,8 +305,7 @@ def __init__(self, hs): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_GET(self, request, version): + async def on_GET(self, request, version): """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -330,18 +323,17 @@ def on_GET(self, request, version): "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) + info = await self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - @defer.inlineCallbacks - def on_DELETE(self, request, version): + async def on_DELETE(self, request, version): """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -354,14 +346,13 @@ def on_DELETE(self, request, version): if version is None: raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version(user_id, version) + await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, version): + async def on_PUT(self, request, version): """ Update the information about a given version of the user's room_keys backup. @@ -382,7 +373,7 @@ def on_PUT(self, request, version): Content-Type: application/json {} """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) @@ -391,7 +382,7 @@ def on_PUT(self, request, version): 400, "No version specified to update", Codes.MISSING_PARAM ) - yield self.e2e_room_keys_handler.update_version(user_id, version, info) + await self.e2e_room_keys_handler.update_version(user_id, version, info) return 200, {} diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index d2c3316eb7f9..ca97330797ca 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( @@ -59,9 +57,8 @@ def __init__(self, hs): self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self._auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) @@ -74,7 +71,7 @@ def on_POST(self, request, room_id): Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = yield self._room_creation_handler.upgrade_room( + new_room_id = await self._room_creation_handler.upgrade_room( requester, room_id, new_version ) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d90e52ed1ab2..501b52fb6c5a 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request from synapse.logging.opentracing import set_tag, trace @@ -51,15 +49,14 @@ def on_PUT(self, request, message_type, txn_id): request, self._put, request, message_type, txn_id ) - @defer.inlineCallbacks - def _put(self, request, message_type, txn_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def _put(self, request, message_type, txn_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) sender_user_id = requester.user.to_string() - yield self.device_message_handler.send_device_message( + await self.device_message_handler.send_device_message( sender_user_id, message_type, content["messages"] ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index ccd8b17b23a5..d8292ce29fc6 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -18,8 +18,6 @@ from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -87,8 +85,7 @@ def __init__(self, hs): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. @@ -96,7 +93,7 @@ def on_GET(self, request): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id @@ -138,7 +135,7 @@ def on_GET(self, request): filter_collection = FilterCollection(filter_object) else: try: - filter_collection = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user.localpart, filter_id ) except StoreError as err: @@ -161,20 +158,20 @@ def on_GET(self, request): since_token = None # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(user.to_string()) + await self._server_notices_sender.on_user_syncing(user.to_string()) affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state( + await self.presence_handler.set_state( user, {"presence": set_presence}, True ) - context = yield self.presence_handler.user_syncing( + context = await self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence ) with context: - sync_result = yield self.sync_handler.wait_for_sync_for_user( + sync_result = await self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, @@ -182,14 +179,13 @@ def on_GET(self, request): ) time_now = self.clock.time_msec() - response_content = yield self.encode_response( + response_content = await self.encode_response( time_now, sync_result, requester.access_token_id, filter_collection ) return 200, response_content - @defer.inlineCallbacks - def encode_response(self, time_now, sync_result, access_token_id, filter): + async def encode_response(self, time_now, sync_result, access_token_id, filter): if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == "federation": @@ -197,7 +193,7 @@ def encode_response(self, time_now, sync_result, access_token_id, filter): else: raise Exception("Unknown event format %s" % (filter.event_format,)) - joined = yield self.encode_joined( + joined = await self.encode_joined( sync_result.joined, time_now, access_token_id, @@ -205,11 +201,11 @@ def encode_response(self, time_now, sync_result, access_token_id, filter): event_formatter, ) - invited = yield self.encode_invited( + invited = await self.encode_invited( sync_result.invited, time_now, access_token_id, event_formatter ) - archived = yield self.encode_archived( + archived = await self.encode_archived( sync_result.archived, time_now, access_token_id, @@ -250,8 +246,9 @@ def encode_presence(events, time_now): ] } - @defer.inlineCallbacks - def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_joined( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the joined rooms in a sync result @@ -272,7 +269,7 @@ def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -283,8 +280,7 @@ def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter return joined - @defer.inlineCallbacks - def encode_invited(self, rooms, time_now, token_id, event_formatter): + async def encode_invited(self, rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -304,7 +300,7 @@ def encode_invited(self, rooms, time_now, token_id, event_formatter): """ invited = {} for room in rooms: - invite = yield self._event_serializer.serialize_event( + invite = await self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, @@ -319,8 +315,9 @@ def encode_invited(self, rooms, time_now, token_id, event_formatter): return invited - @defer.inlineCallbacks - def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_archived( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the archived rooms in a sync result @@ -341,7 +338,7 @@ def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatt """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -352,8 +349,7 @@ def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatt return joined - @defer.inlineCallbacks - def encode_room( + async def encode_room( self, room, time_now, token_id, joined, only_fields, event_formatter ): """ @@ -401,8 +397,8 @@ def serialize(events): event.room_id, ) - serialized_state = yield serialize(state_events) - serialized_timeline = yield serialize(timeline_events) + serialized_state = await serialize(state_events) + serialized_timeline = await serialize(timeline_events) account_data = room.account_data diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 3b555669a0cc..a3f12e8a774d 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -37,13 +35,12 @@ def __init__(self, hs): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) return 200, {"tags": tags} @@ -64,27 +61,25 @@ def __init__(self, hs): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) + max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 2e8d6724717a..23709960ad85 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet @@ -35,11 +33,10 @@ def __init__(self, hs): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols() + protocols = await self.appservice_handler.get_3pe_protocols() return 200, protocols @@ -52,11 +49,10 @@ def __init__(self, hs): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols( + protocols = await self.appservice_handler.get_3pe_protocols( only_protocol=protocol ) if protocol in protocols: @@ -74,14 +70,13 @@ def __init__(self, hs): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields ) @@ -97,14 +92,13 @@ def __init__(self, hs): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields ) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 2da0f5581126..83f3b6b70ad7 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet @@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 2863affbaba6..bef91a2d3ed1 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -38,8 +36,7 @@ def __init__(self, hs): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Searches for users in directory Returns: @@ -56,7 +53,7 @@ def on_POST(self, request): ] } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: @@ -72,7 +69,7 @@ def on_POST(self, request): except Exception: raise SynapseError(400, "`search_term` is required field") - results = yield self.user_directory_handler.search_users( + results = await self.user_directory_handler.search_users( user_id, search_term, limit ) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index bb30ce3f3437..2a477ad22e0e 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fb0d02aa83db..6b978be8764b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -402,7 +402,7 @@ def _expire_url_cache_data(self): logger.info("Running url preview cache expiry") - if not (yield self.store.has_completed_background_updates()): + if not (yield self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 8cf415e29dee..c234ea74212f 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -129,5 +129,8 @@ def crop(self, width, height, output_type): def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() - output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) + fmt = self.FORMATS[output_type] + if fmt == "JPEG": + output_image = output_image.convert("RGB") + output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io diff --git a/synapse/server.py b/synapse/server.py index be9af7f986ef..2db3dab22115 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -238,8 +238,7 @@ def __init__(self, hostname, reactor=None, **kwargs): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - datastore = self.DATASTORE_CLASS(conn, self) - self.datastores = DataStores(datastore, conn, self) + self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) conn.commit() self.start_time = int(self.get_clock().time()) logger.info("Finished setting up.") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 0460fe8cc9b0..ec89f645d401 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -17,10 +17,10 @@ """ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple -databases). The `data_stores` are classes that talk directly to a single -database and have associated schemas, background updates, etc. On top of those -there are (or will be) classes that provide high level interfaces that combine -calls to multiple `data_stores`. +databases). The `Database` class represents a single physical database. The +`data_stores` are classes that talk directly to a `Database` instance and have +associated schemas, background updates, etc. On top of those there are classes +that provide high level interfaces that combine calls to multiple `data_stores`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are @@ -49,15 +49,3 @@ def __init__(self, hs, stores: DataStores): self.persistence = EventsPersistenceStorage(hs, stores) self.purge_events = PurgeEventsStorage(hs, stores) self.state = StateGroupStorage(hs, stores) - - -def are_all_users_on_domain(txn, database_engine, domain): - sql = database_engine.convert_param_style( - "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" - ) - pat = "%:" + domain - txn.execute(sql, (pat,)) - num_not_matching = txn.fetchall()[0][0] - if num_not_matching == 0: - return True - return False diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 459901ac60a4..b7637b5dc0e8 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -14,1433 +14,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging import random -import sys -import threading -import time -from typing import Iterable, Tuple -from six import PY2, iteritems, iterkeys, itervalues -from six.moves import builtins, intern, range +from six import PY2 +from six.moves import builtins from canonicaljson import json -from prometheus_client import Histogram -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.context import LoggingContext, make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.database import LoggingTransaction # noqa: F401 +from synapse.storage.database import make_in_list_sql_clause # noqa: F401 +from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util import batch_iter -from synapse.util.caches.descriptors import Cache -from synapse.util.stringutils import exception_to_unicode - -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 - -sql_logger = logging.getLogger("synapse.storage.SQL") -transaction_logger = logging.getLogger("synapse.storage.txn") -perf_logger = logging.getLogger("synapse.storage.TIME") - -sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") - -sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) -sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) - - -# Unique indexes which have been added in background updates. Maps from table name -# to the name of the background update which added the unique index to that table. -# -# This is used by the upsert logic to figure out which tables are safe to do a proper -# UPSERT on: until the relevant background update has completed, we -# have to emulate an upsert by locking the table. -# -UNIQUE_INDEX_BACKGROUND_UPDATES = { - "user_ips": "user_ips_device_unique_index", - "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", - "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", - "event_search": "event_search_event_id_idx", -} - -# This is a special cache name we use to batch multiple invalidations of caches -# based on the current state when notifying workers over replication. -_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" - -class LoggingTransaction(object): - """An object that almost-transparently proxies for the 'txn' object - passed to the constructor. Adds logging and metrics to the .execute() - method. +class SQLBaseStore(object): + """Base class for data stores that holds helper functions. - Args: - txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to - that have been added by `call_after` which should be run on - successful completion of the transaction. None indicates that no - callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended - to that have been added by `call_on_exception` which should be run - if transaction ends with an error. None indicates that no callbacks - should be allowed to be scheduled to run. + Note that multiple instances of this class will exist as there will be one + per data store (and not one per physical database). """ - __slots__ = [ - "txn", - "name", - "database_engine", - "after_callbacks", - "exception_callbacks", - ] - - def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None - ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) - - def call_after(self, callback, *args, **kwargs): - """Call the given callback on the main twisted thread after the - transaction has finished. Used to invalidate the caches on the - correct thread. - """ - self.after_callbacks.append((callback, args, kwargs)) - - def call_on_exception(self, callback, *args, **kwargs): - self.exception_callbacks.append((callback, args, kwargs)) - - def __getattr__(self, name): - return getattr(self.txn, name) - - def __setattr__(self, name, value): - setattr(self.txn, name, value) - - def __iter__(self): - return self.txn.__iter__() - - def execute_batch(self, sql, args): - if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch - - self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) - else: - for val in args: - self.execute(sql, val) - - def execute(self, sql, *args): - self._do_execute(self.txn.execute, sql, *args) - - def executemany(self, sql, *args): - self._do_execute(self.txn.executemany, sql, *args) - - def _make_sql_one_line(self, sql): - "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) - - def _do_execute(self, func, sql, *args): - sql = self._make_sql_one_line(sql) - - # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] {%s} %s", self.name, sql) - - sql = self.database_engine.convert_param_style(sql) - if args: - try: - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) - except Exception: - # Don't let logging failures stop SQL from working - pass - - start = time.time() - - try: - return func(sql, *args) - except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) - raise - finally: - secs = time.time() - start - sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) - sql_query_timer.labels(sql.split()[0]).observe(secs) - - -class PerformanceCounters(object): - def __init__(self): - self.current_counters = {} - self.previous_counters = {} - - def update(self, key, duration_secs): - count, cum_time = self.current_counters.get(key, (0, 0)) - count += 1 - cum_time += duration_secs - self.current_counters[key] = (count, cum_time) - - def interval(self, interval_duration_secs, limit=3): - counters = [] - for name, (count, cum_time) in iteritems(self.current_counters): - prev_count, prev_time = self.previous_counters.get(name, (0, 0)) - counters.append( - ( - (cum_time - prev_time) / interval_duration_secs, - count - prev_count, - name, - ) - ) - - self.previous_counters = dict(self.current_counters) - - counters.sort(reverse=True) - - top_n_counters = ", ".join( - "%s(%d): %.3f%%" % (name, count, 100 * ratio) - for ratio, count, name in counters[:limit] - ) - - return top_n_counters - - -class SQLBaseStore(object): - _TXN_ID = 0 - - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() - - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 - - # TODO(paul): These can eventually be removed once the metrics code - # is running in mainline, and we have some nice monitoring frontends - # to watch it - self._txn_perf_counters = PerformanceCounters() - - self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size - ) - - self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] - self._event_fetch_ongoing = 0 - - self._pending_ds = [] - self.database_engine = hs.database_engine - - # A set of tables that are not safe to use native upserts in. - self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) - - self._account_validity = self.hs.config.account_validity - - # We add the user_directory_search table to the blacklist on SQLite - # because the existing search table does not have an index, making it - # unsafe to use native upserts. - if isinstance(self.database_engine, Sqlite3Engine): - self._unsafe_to_upsert_tables.add("user_directory_search") - - if self.database_engine.can_native_upsert: - # Check ASAP (and then later, every 1s) to see if we have finished - # background updates of tables that aren't safe to update. - self._clock.call_later( - 0.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - + self.db = database self.rand = random.SystemRandom() - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - - @defer.inlineCallbacks - def _check_safe_to_upsert(self): - """ - Is it safe to use native UPSERT? - - If there are background updates, we will need to wait, as they may be - the addition of indexes that set the UNIQUE constraint that we require. - - If the background updates have not completed, wait 15 sec and check again. - """ - updates = yield self._simple_select_list( - "background_updates", - keyvalues=None, - retcols=["update_name"], - desc="check_background_updates", - ) - updates = [x["update_name"] for x in updates] - - for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: - logger.debug("Now safe to upsert in %s", table) - self._unsafe_to_upsert_tables.discard(table) - - # If there's any updates still running, reschedule to run. - if updates: - self._clock.call_later( - 15.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - yield self.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self._simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - - def start_profiling(self): - self._previous_loop_ts = monotonic_time() - - def loop(): - curr = self._current_txn_total_time - prev = self._previous_txn_total_time - self._previous_txn_total_time = curr - - time_now = monotonic_time() - time_then = self._previous_loop_ts - self._previous_loop_ts = time_now - - duration = time_now - time_then - ratio = (curr - prev) / duration - - top_three_counters = self._txn_perf_counters.interval(duration, limit=3) - - perf_logger.info( - "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters - ) - - self._clock.looping_call(loop, 10000) - - def _new_transaction( - self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs - ): - start = monotonic_time() - txn_id = self._TXN_ID - - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) - - name = "%s-%x" % (desc, txn_id) - - transaction_logger.debug("[TXN START] {%s}", name) - - try: - i = 0 - N = 5 - while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.database_engine, - after_callbacks, - exception_callbacks, - ) - try: - r = func(cursor, *args, **kwargs) - conn.commit() - return r - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) - continue - raise - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), - ) - continue - raise - finally: - # we're either about to retry with a new cursor, or we're about to - # release the connection. Once we release the connection, it could - # get used for another query, which might do a conn.rollback(). - # - # In the latter case, even though that probably wouldn't affect the - # results of this transaction, python's sqlite will reset all - # statements on the connection [1], which will make our cursor - # invalid [2]. - # - # In any case, continuing to read rows after commit()ing seems - # dubious from the PoV of ACID transactional semantics - # (sqlite explicitly says that once you commit, you may see rows - # from subsequent updates.) - # - # In psycopg2, cursors are essentially a client-side fabrication - - # all the data is transferred to the client side when the statement - # finishes executing - so in theory we could go on streaming results - # from the cursor, but attempting to do so would make us - # incompatible with sqlite, so let's make sure we're not doing that - # by closing the cursor. - # - # (*named* cursors in psycopg2 are different and are proper server- - # side things, but (a) we don't use them and (b) they are implicitly - # closed by ending the transaction anyway.) - # - # In short, if we haven't finished with the cursor yet, that's a - # problem waiting to bite us. - # - # TL;DR: we're done with the cursor, so we can close it. - # - # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 - # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 - cursor.close() - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = monotonic_time() - duration = end - start - - LoggingContext.current_context().add_database_transaction(duration) - - transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) - - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, duration) - sql_txn_timer.labels(desc).observe(duration) - - @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): - """Starts a transaction on the database and runs a given function - - Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a - database transaction (twisted.enterprise.adbapi.Transaction) as - its first argument, followed by `args` and `kwargs`. - - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - after_callbacks = [] - exception_callbacks = [] - - if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warning("Starting db txn '%s' from sentinel context", desc) - - try: - result = yield self.runWithConnection( - self._new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - **kwargs - ) - - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except: # noqa: E722, as we reraise the exception this is fine. - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise - - return result - - @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runWithConnection() method on the underlying db_pool. - - Arguments: - func (func): callback function, which will be called with a - database connection (twisted.enterprise.adbapi.Connection) as - its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - parent_context = LoggingContext.current_context() - if parent_context == LoggingContext.sentinel: - logger.warning( - "Starting db connection from sentinel context: metrics will be lost" - ) - parent_context = None - - start_time = monotonic_time() - - def inner_func(conn, *args, **kwargs): - with LoggingContext("runWithConnection", parent_context) as context: - sched_duration_sec = monotonic_time() - start_time - sql_scheduling_timer.observe(sched_duration_sec) - context.add_database_scheduled(sched_duration_sec) - - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - return func(conn, *args, **kwargs) - - result = yield make_deferred_yieldable( - self._db_pool.runWithConnection(inner_func, *args, **kwargs) - ) - - return result - - @staticmethod - def cursor_to_dict(cursor): - """Converts a SQL cursor into an list of dicts. - - Args: - cursor : The DBAPI cursor which has executed a query. - Returns: - A list of dicts where the key is the column header. - """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) - return results - - def _execute(self, desc, decoder, query, *args): - """Runs a single query for a result set. - - Args: - decoder - The function which can resolve the cursor results to - something meaningful. - query - The query string to execute - *args - Query args. - Returns: - The result of decoder(results) - """ - - def interaction(txn): - txn.execute(query, args) - if decoder: - return decoder(txn) - else: - return txn.fetchall() - - return self.runInteraction(desc, interaction) - - # "Simple" SQL API methods that operate on a single table with no JOINs, - # no complex WHERE clauses, just a dict of values for columns. - - @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): - """Executes an INSERT query on the named table. - - Args: - table : string giving the table name - values : dict of new column names and values for them - or_ignore : bool stating whether an exception should be raised - when a conflicting row already exists. If True, False will be - returned by the function instead - desc : string giving a description of the transaction - - Returns: - bool: Whether the row was inserted or not. Only useful when - `or_ignore` is True - """ - try: - yield self.runInteraction(desc, self._simple_insert_txn, table, values) - except self.database_engine.module.IntegrityError: - # We have to do or_ignore flag at this layer, since we can't reuse - # a cursor after we receive an error from the db. - if not or_ignore: - raise - return False - return True - - @staticmethod - def _simple_insert_txn(txn, table, values): - keys, vals = zip(*values.items()) - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys), - ", ".join("?" for _ in keys), - ) - - txn.execute(sql, vals) - - def _simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self._simple_insert_many_txn, table, values) - - @staticmethod - def _simple_insert_many_txn(txn, table, values): - if not values: - return - - # This is a *slight* abomination to get a list of tuples of key names - # and a list of tuples of value names. - # - # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] - # - # The sort is to ensure that we don't rely on dictionary iteration - # order. - keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] - ) - - for k in keys: - if k != keys[0]: - raise RuntimeError("All items must have the same keys") - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), - ) - - txn.executemany(sql, vals) - - @defer.inlineCallbacks - def _simple_upsert( - self, - table, - keyvalues, - values, - insertion_values={}, - desc="_simple_upsert", - lock=True, - ): - """ - - `lock` should generally be set to True (the default), but can be set - to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key columns and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - Deferred(None or bool): Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - attempts = 0 - while True: - try: - result = yield self.runInteraction( - desc, - self._simple_upsert_txn, - table, - keyvalues, - values, - insertion_values, - lock=lock, - ) - return result - except self.database_engine.module.IntegrityError as e: - attempts += 1 - if attempts >= 5: - # don't retry forever, because things other than races - # can cause IntegrityErrors - raise - - # presumably we raced with another transaction: let's retry. - logger.warning( - "IntegrityError when upserting into %s; retrying: %s", table, e - ) - - def _simple_upsert_txn( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Pick the UPSERT method which works best on the platform. Either the - native one (Pg9.5+, recent SQLites), or fall back to an emulated method. - - Args: - txn: The transaction to use. - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self._simple_upsert_txn_native_upsert( - txn, table, keyvalues, values, insertion_values=insertion_values - ) - else: - return self._simple_upsert_txn_emulated( - txn, - table, - keyvalues, - values, - insertion_values=insertion_values, - lock=lock, - ) - - def _simple_upsert_txn_emulated( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - bool: Return True if a new entry was created, False if an existing - one was updated. - """ - # We need to lock the table :(, unless we're *really* careful - if lock: - self.database_engine.lock_table(txn, table) - - def _getwhere(key): - # If the value we're passing in is None (aka NULL), we need to use - # IS, not =, as NULL = NULL equals NULL (False). - if keyvalues[key] is None: - return "%s IS ?" % (key,) - else: - return "%s = ?" % (key,) - - if not values: - # If `values` is empty, then all of the values we care about are in - # the unique key, so there is nothing to UPDATE. We can just do a - # SELECT instead to see if it exists. - sql = "SELECT 1 FROM %s WHERE %s" % ( - table, - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(keyvalues.values()) - txn.execute(sql, sqlargs) - if txn.fetchall(): - # We have an existing record. - return False - else: - # First try to update. - sql = "UPDATE %s SET %s WHERE %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(values.values()) + list(keyvalues.values()) - - txn.execute(sql, sqlargs) - if txn.rowcount > 0: - # successfully updated at least one row. - return False - - # We didn't find any existing rows, so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) - - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ) - txn.execute(sql, list(allvalues.values())) - # successfully inserted - return True - - def _simple_upsert_txn_native_upsert( - self, txn, table, keyvalues, values, insertion_values={} - ): - """ - Use the native UPSERT functionality in recent PostgreSQL versions. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - Returns: - None - """ - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(insertion_values) - - if not values: - latter = "NOTHING" - else: - allvalues.update(values) - latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - - sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ", ".join(k for k in keyvalues), - latter, - ) - txn.execute(sql, list(allvalues.values())) - - def _simple_upsert_many_txn( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self._simple_upsert_many_txn_native_upsert( - txn, table, key_names, key_values, value_names, value_values - ) - else: - return self._simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values - ) - - def _simple_upsert_many_txn_emulated( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, but without native UPSERT support or batching. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] - - for keyv, valv in zip(key_values, value_values): - _keys = {x: y for x, y in zip(key_names, keyv)} - _vals = {x: y for x, y in zip(value_names, valv)} - - self._simple_upsert_txn_emulated(txn, table, _keys, _vals) - - def _simple_upsert_many_txn_native_upsert( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, using batching where possible. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - allnames = [] - allnames.extend(key_names) - allnames.extend(value_names) - - if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. - latter = "NOTHING" - value_values = [() for x in range(len(key_values))] - else: - latter = "UPDATE SET " + ", ".join( - k + "=EXCLUDED." + k for k in value_names - ) - - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( - table, - ", ".join(k for k in allnames), - ", ".join("?" for _ in allnames), - ", ".join(key_names), - latter, - ) - - args = [] - - for x, y in zip(key_values, value_values): - args.append(tuple(x) + tuple(y)) - - return txn.execute_batch(sql, args) - - def _simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning multiple columns from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcols : list of strings giving the names of the columns to return - - allow_none : If true, return None instead of failing if the SELECT - statement returns no rows - """ - return self.runInteraction( - desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none - ) - - def _simple_select_one_onecol( - self, - table, - keyvalues, - retcol, - allow_none=False, - desc="_simple_select_one_onecol", - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcol : string giving the name of the column to return - """ - return self.runInteraction( - desc, - self._simple_select_one_onecol_txn, - table, - keyvalues, - retcol, - allow_none=allow_none, - ) - - @classmethod - def _simple_select_one_onecol_txn( - cls, txn, table, keyvalues, retcol, allow_none=False - ): - ret = cls._simple_select_onecol_txn( - txn, table=table, keyvalues=keyvalues, retcol=retcol - ) - - if ret: - return ret[0] - else: - if allow_none: - return None - else: - raise StoreError(404, "No row found") - - @staticmethod - def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} - - if keyvalues: - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - txn.execute(sql, list(keyvalues.values())) - else: - txn.execute(sql) - - return [r[0] for r in txn] - - def _simple_select_onecol( - self, table, keyvalues, retcol, desc="_simple_select_onecol" - ): - """Executes a SELECT query on the named table, which returns a list - comprising of the values of the named column from the selected rows. - - Args: - table (str): table name - keyvalues (dict|None): column names and values to select the rows with - retcol (str): column whos value we wish to retrieve. - - Returns: - Deferred: Results in a list - """ - return self.runInteraction( - desc, self._simple_select_onecol_txn, table, keyvalues, retcol - ) - - def _simple_select_list( - self, table, keyvalues, retcols, desc="_simple_select_list" - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, self._simple_select_list_txn, table, keyvalues, retcols - ) - - @classmethod - def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - """ - if keyvalues: - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - txn.execute(sql, list(keyvalues.values())) - else: - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - txn.execute(sql) - - return cls.cursor_to_dict(txn) - - @defer.inlineCallbacks - def _simple_select_many_batch( - self, - table, - column, - iterable, - retcols, - keyvalues={}, - desc="_simple_select_many_batch", - batch_size=100, - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - results = [] - - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: - rows = yield self.runInteraction( - desc, - self._simple_select_many_txn, - table, - column, - chunk, - keyvalues, - retcols, - ) - - results.extend(rows) - - return results - - @classmethod - def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - if not iterable: - return [] - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join(clauses), - ) - - txn.execute(sql, values) - return cls.cursor_to_dict(txn) - - def _simple_update(self, table, keyvalues, updatevalues, desc): - return self.runInteraction( - desc, self._simple_update_txn, table, keyvalues, updatevalues - ) - - @staticmethod - def _simple_update_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) - - return txn.rowcount - - def _simple_update_one( - self, table, keyvalues, updatevalues, desc="_simple_update_one" - ): - """Executes an UPDATE query on the named table, setting new values for - columns in a row matching the key values. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - updatevalues : dict giving column names and values to update - retcols : optional list of column names to return - - If present, retcols gives a list of column names on which to perform - a SELECT statement *before* performing the UPDATE statement. The values - of these will be returned in a dict. - - These are performed within the same transaction, allowing an atomic - get-and-set. This can be used to implement compare-and-set by putting - the update column in the 'keyvalues' dict as well. - """ - return self.runInteraction( - desc, self._simple_update_one_txn, table, keyvalues, updatevalues - ) - - @classmethod - def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - - if rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(select_sql, list(keyvalues.values())) - row = txn.fetchone() - - if not row: - if allow_none: - return None - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - return dict(zip(retcols, row)) - - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) - - @staticmethod - def _simple_delete_one_txn(txn, table, keyvalues): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - if txn.rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) - - @staticmethod - def _simple_delete_txn(txn, table, keyvalues): - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - return txn.rowcount - - def _simple_delete_many(self, table, column, iterable, keyvalues, desc): - return self.runInteraction( - desc, self._simple_delete_many_txn, table, column, iterable, keyvalues - ) - - @staticmethod - def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): - """Executes a DELETE query on the named table. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - - Returns: - int: Number rows deleted - """ - if not iterable: - return 0 - - sql = "DELETE FROM %s" % table - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - txn.execute(sql, values) - - return txn.rowcount - - def _get_cache_dict( - self, db_conn, table, entity_column, stream_column, max_value, limit=100000 - ): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - %(limit)s" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - "limit": limit, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - - cache = {row[0]: int(row[1]) for row in txn} - - txn.close() - - if cache: - min_val = min(itervalues(cache)) - else: - min_val = max_value - - return cache, min_val - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - txn.call_after(cache_func.invalidate, keys) - self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): - """Special case invalidation of caches based on current state. - - We special case this so that we can batch the cache invalidations into a - single replication poke. - - Args: - txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed - """ - txn.call_after(self._invalidate_state_caches, room_id, members_changed) - - if members_changed: - # We need to be careful that the size of the `members_changed` list - # isn't so large that it causes problems sending over replication, so we - # send them in chunks. - # Max line length is 16K, and max user ID length is 255, so 50 should - # be safe. - for chunk in batch_iter(members_changed, 50): - keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys - ) - else: - # if no members changed, we still need to invalidate the other caches. - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, [room_id] - ) - def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -1474,226 +77,6 @@ def _attempt_to_invalidate_cache(self, cache_name, key): # which is fine. pass - def _send_invalidation_to_replication(self, txn, cache_name, keys): - """Notifies replication that given cache has been invalidated. - - Note that this does *not* invalidate the cache locally. - - Args: - txn - cache_name (str) - keys (iterable[str]) - """ - - if isinstance(self.database_engine, PostgresEngine): - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) - txn.call_after(self.hs.get_notifier().on_new_replication_data) - - self._simple_insert_txn( - txn, - table="cache_invalidation_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_name, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - }, - ) - - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def _simple_select_list_paginate( - self, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - desc="_simple_select_list_paginate", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self._simple_select_list_paginate_txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction=order_direction, - ) - - @classmethod - def _simple_select_list_paginate_txn( - cls, - txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - if order_direction not in ["ASC", "DESC"]: - raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") - - if keyvalues: - where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) - else: - where_clause = "" - - sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( - ", ".join(retcols), - table, - where_clause, - orderby, - order_direction, - ) - txn.execute(sql, list(keyvalues.values()) + [limit, start]) - - return cls.cursor_to_dict(txn) - - def get_user_count_txn(self, txn): - """Get a total number of registered users in the users list. - - Args: - txn : Transaction object - Returns: - int : number of users - """ - sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" - txn.execute(sql_count) - return txn.fetchone()[0] - - def _simple_search_list( - self, table, term, col, retcols, desc="_simple_search_list" - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - - return self.runInteraction( - desc, self._simple_search_list_txn, table, term, col, retcols - ) - - @classmethod - def _simple_search_list_txn(cls, txn, table, term, col, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - if term: - sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) - termvalues = ["%%" + term + "%%"] - txn.execute(sql, termvalues) - else: - return 0 - - return cls.cursor_to_dict(txn) - - @property - def database_engine_name(self): - return self.database_engine.module.__name__ - - def get_server_version(self): - """Returns a string describing the server version number""" - return self.database_engine.server_version - - -class _RollbackButIsFineException(Exception): - """ This exception is used to rollback a transaction without implying - something went wrong. - """ - - pass - def db_to_json(db_content): """ @@ -1722,30 +105,3 @@ def db_to_json(db_content): except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise - - -def make_in_list_sql_clause( - database_engine, column: str, iterable: Iterable -) -> Tuple[str, Iterable]: - """Returns an SQL clause that checks the given column is in the iterable. - - On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres - it expands to `column = ANY(?)`. While both DBs support the `IN` form, - using the `ANY` form on postgres means that it views queries with - different length iterables as the same, helping the query stats. - - Args: - database_engine - column: Name of the column - iterable: The values to check the column against. - - Returns: - A tuple of SQL query and the args - """ - - if database_engine.supports_using_any_list: - # This should hopefully be faster, but also makes postgres query - # stats easier to understand. - return "%s = ANY(?)" % (column,), [list(iterable)] - else: - return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 37d469ffd703..4f97fd5ab6a2 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -22,7 +22,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from . import engines -from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ def total_items_per_ms(self): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdateStore(SQLBaseStore): +class BackgroundUpdater(object): """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to @@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, db_conn, hs): - super(BackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, hs, database): + self._clock = hs.get_clock() + self.db = database + self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} @@ -101,9 +102,7 @@ def run_background_updates(self, sleep=True): logger.info("Starting background schema updates") while True: if sleep: - yield self.hs.get_clock().sleep( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 - ) + yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) try: result = yield self.do_next_background_update( @@ -139,7 +138,7 @@ def has_completed_background_updates(self): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self._simple_select_onecol( + updates = yield self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -161,7 +160,7 @@ async def has_completed_background_update(self, update_name) -> bool: if update_name in self._background_update_queue: return False - update_exists = await self._simple_select_one_onecol( + update_exists = await self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -184,7 +183,7 @@ def do_next_background_update(self, desired_duration_ms): no more work to do. """ if not self._background_update_queue: - updates = yield self._simple_select_list( + updates = yield self.db.simple_select_list( "background_updates", keyvalues=None, retcols=("update_name", "depends_on"), @@ -226,7 +225,7 @@ def _do_background_update(self, update_name, desired_duration_ms): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self._simple_select_one_onecol( + progress_json = yield self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -380,7 +379,7 @@ def create_index_sqlite(conn): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.database_engine, engines.PostgresEngine): + if isinstance(self.db.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None @@ -391,7 +390,7 @@ def create_index_sqlite(conn): def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.runWithConnection(runner) + yield self.db.runWithConnection(runner) yield self._end_background_update(update_name) return 1 @@ -413,7 +412,7 @@ def start_background_update(self, update_name, progress): self._background_update_queue = [] progress_json = json.dumps(progress) - return self._simple_insert( + return self.db.simple_insert( "background_updates", {"update_name": update_name, "progress_json": progress_json}, ) @@ -429,7 +428,7 @@ def _end_background_update(self, update_name): self._background_update_queue = [ name for name in self._background_update_queue if name != update_name ] - return self._simple_delete_one( + return self.db.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -444,7 +443,7 @@ def _background_update_progress_txn(self, txn, update_name, progress): progress_json = json.dumps(progress) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cb184a98cc93..cafedd5c0da5 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.database import Database +from synapse.storage.prepare_database import prepare_database + class DataStores(object): """The various data stores. @@ -20,7 +23,14 @@ class DataStores(object): These are low level interfaces to physical databases. """ - def __init__(self, main_store, db_conn, hs): - # Note we pass in the main store here as workers use a different main + def __init__(self, main_store_class, db_conn, hs): + # Note we pass in the main store class here as workers use a different main # store. - self.main = main_store + database = Database(hs) + + # Check that db is correctly configured. + database.engine.check_database(db_conn.cursor()) + + prepare_database(db_conn, database.engine, config=hs.config) + + self.main = main_store_class(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 10c940df1e2f..c577c0df5ffb 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -19,9 +19,8 @@ import logging import time -from twisted.internet import defer - from synapse.api.constants import PresenceState +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, @@ -32,6 +31,7 @@ from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .cache import CacheInvalidationStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -110,11 +110,22 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, + CacheInvalidationStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine + + all_users_native = are_all_users_on_domain( + db_conn.cursor(), database.engine, hs.hostname + ) + if not all_users_native: + raise Exception( + "Found users in database not native to %s!\n" + "You cannot changed a synapse server_name after it's been configured" + % (hs.hostname,) + ) self._stream_id_gen = StreamIdGenerator( db_conn, @@ -169,9 +180,11 @@ def __init__(self, db_conn, hs): else: self._cache_id_gen = None + super(DataStore, self).__init__(database, db_conn, hs) + self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self._get_cache_dict( + presence_cache_prefill, min_presence_val = self.db.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -185,7 +198,7 @@ def __init__(self, db_conn, hs): ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -200,7 +213,7 @@ def __init__(self, db_conn, hs): ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -226,7 +239,7 @@ def __init__(self, db_conn, hs): ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -240,7 +253,7 @@ def __init__(self, db_conn, hs): prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -260,8 +273,6 @@ def __init__(self, db_conn, hs): # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() - super(DataStore, self).__init__(db_conn, hs) - def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None @@ -281,7 +292,7 @@ def _get_active_presence(self, db_conn): txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) txn.close() for row in rows: @@ -294,7 +305,7 @@ def count_daily_users(self): Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday) + return self.db.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ @@ -304,7 +315,7 @@ def count_monthly_users(self): amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.runInteraction( + return self.db.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -404,7 +415,7 @@ def _count_r30_users(txn): return results - return self.runInteraction("count_r30_users", _count_r30_users) + return self.db.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -469,50 +480,73 @@ def _generate_user_daily_visits(txn): # frequently self._last_user_visit_update = now - return self.runInteraction( + return self.db.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) def get_users(self): - """Function to reterive a list of users in users table. + """Function to retrieve a list of users in users table. Args: Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_select_list( + return self.db.simple_select_list( table="users", keyvalues={}, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "user_type", + "deactivated", + ], desc="get_users", ) - @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. + def get_users_paginate( + self, start, limit, name=None, guests=True, deactivated=False + ): + """Function to retrieve a paginated list of users from + users list. This will return a json list of users. Args: - order (str): column name to order the select by this column start (int): start number to begin the query from - limit (int): number of rows to reterive + limit (int): number of rows to retrieve + name (string): filter for user names + guests (bool): whether to in include guest users + deactivated (bool): whether to include deactivated users Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} + defer.Deferred: resolves to list[dict[str, Any]] """ - users = yield self.runInteraction( - "get_users_paginate", - self._simple_select_list_paginate_txn, + name_filter = {} + if name: + name_filter["name"] = "%" + name + "%" + + attr_filter = {} + if not guests: + attr_filter["is_guest"] = False + if not deactivated: + attr_filter["deactivated"] = False + + return self.db.simple_select_list_paginate( + desc="get_users_paginate", table="users", - keyvalues={"is_guest": False}, - orderby=order, + orderby="name", start=start, limit=limit, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + filters=name_filter, + keyvalues=attr_filter, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "user_type", + "deactivated", + ], ) - count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) - retval = {"users": users, "total": count} - return retval def search_users(self, term): """Function to search users list for one or more users with @@ -524,10 +558,22 @@ def search_users(self, term): Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_search_list( + return self.db.simple_search_list( table="users", term=term, col="name", retcols=["name", "password_hash", "is_guest", "admin", "user_type"], desc="search_users", ) + + +def are_all_users_on_domain(txn, database_engine, domain): + sql = database_engine.convert_param_style( + "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" + ) + pat = "%:" + domain + txn.execute(sql, (pat,)) + num_not_matching = txn.fetchall()[0][0] + if num_not_matching == 0: + return True + return False diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 22093484eda4..46b494b33465 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,6 +22,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(db_conn, hs) + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -67,7 +68,7 @@ def get_account_data_for_user(self, user_id): """ def get_account_data_for_user_txn(txn): - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +79,7 @@ def get_account_data_for_user_txn(txn): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -92,7 +93,7 @@ def get_account_data_for_user_txn(txn): return global_account_data, by_room - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @@ -102,7 +103,7 @@ def get_global_account_data_by_type_for_user(self, data_type, user_id): Returns: Deferred: A dict """ - result = yield self._simple_select_one_onecol( + result = yield self.db.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +128,7 @@ def get_account_data_for_room(self, user_id, room_id): """ def get_account_data_for_room_txn(txn): - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -138,7 +139,7 @@ def get_account_data_for_room_txn(txn): row["account_data_type"]: json.loads(row["content"]) for row in rows } - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @@ -156,7 +157,7 @@ def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type """ def get_account_data_for_room_and_type_txn(txn): - content_json = self._simple_select_one_onecol_txn( + content_json = self.db.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -170,7 +171,7 @@ def get_account_data_for_room_and_type_txn(txn): return json.loads(content_json) if content_json else None - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -207,7 +208,7 @@ def get_updated_account_data_txn(txn): room_results = txn.fetchall() return global_results, room_results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn ) @@ -250,9 +251,9 @@ def get_updated_account_data_for_user_txn(txn): user_id, int(stream_id) ) if not changed: - return {}, {} + return defer.succeed(({}, {})) - return self.runInteraction( + return self.db.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -270,12 +271,12 @@ def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) - super(AccountDataStore, self).__init__(db_conn, hs) + super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream @@ -300,9 +301,9 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so _simple_upsert will + # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self._simple_upsert( + yield self.db.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -346,9 +347,9 @@ def add_account_data_for_user(self, user_id, account_data_type, content): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so _simple_upsert will retry if + # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self._simple_upsert( + yield self.db.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -388,4 +389,4 @@ def _update(txn): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.runInteraction("update_account_data_max_stream_id", _update) + return self.db.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 81babf2029ad..b2f39649fdbe 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -24,6 +24,7 @@ from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache @@ -133,7 +134,7 @@ def get_appservices_by_state(self, state): A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self._simple_select_list( + results = yield self.db.simple_select_list( "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -155,7 +156,7 @@ def get_appservice_state(self, service): Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self._simple_select_one( + result = yield self.db.simple_select_one( "application_services_state", dict(as_id=service.id), ["state"], @@ -175,7 +176,7 @@ def set_appservice_state(self, service, state): Returns: A Deferred which resolves when the state was set successfully. """ - return self._simple_upsert( + return self.db.simple_upsert( "application_services_state", dict(as_id=service.id), dict(state=state) ) @@ -216,7 +217,7 @@ def _create_appservice_txn(txn): ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.runInteraction("create_appservice_txn", _create_appservice_txn) + return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) def complete_appservice_txn(self, txn_id, service): """Completes an application service transaction. @@ -249,7 +250,7 @@ def _complete_appservice_txn(txn): ) # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, "application_services_state", dict(as_id=service.id), @@ -257,11 +258,13 @@ def _complete_appservice_txn(txn): ) # Delete txn - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) - return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) + return self.db.runInteraction( + "complete_appservice_txn", _complete_appservice_txn + ) @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): @@ -283,7 +286,7 @@ def _get_oldest_unsent_txn(txn): " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None @@ -291,7 +294,7 @@ def _get_oldest_unsent_txn(txn): return entry - entry = yield self.runInteraction( + entry = yield self.db.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -321,7 +324,7 @@ def set_appservice_last_pos_txn(txn): "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.runInteraction( + return self.db.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) @@ -350,7 +353,7 @@ def get_new_events_for_appservice_txn(txn): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py new file mode 100644 index 000000000000..54ed8574c48e --- /dev/null +++ b/synapse/storage/data_stores/main/cache.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine +from synapse.util import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationStore(SQLBaseStore): + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication(self, txn, cache_name, keys): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name (str) + keys (iterable[str]) + """ + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_on_exception(ctx.__exit__, None, None, None) + txn.call_after(ctx.__exit__, None, None, None) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + self.db.simple_insert_txn( + txn, + table="cache_invalidation_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_name, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 706c6a1f3fac..add3037b6957 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -20,9 +20,10 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage import background_updates -from synapse.storage._base import Cache +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -32,41 +33,41 @@ LAST_SEEN_GRANULARITY = 120 * 1000 -class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) +class ClientIpBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -75,12 +76,12 @@ def __init__(self, db_conn, hs): ) # Drop the old non-unique index - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # Update the last seen info in devices. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "devices_last_seen", self._devices_last_seen_update ) @@ -91,8 +92,8 @@ def f(conn): txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.runWithConnection(f) - yield self._end_background_update("user_ips_drop_nonunique_index") + yield self.db.runWithConnection(f) + yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") return 1 @defer.inlineCallbacks @@ -106,9 +107,9 @@ def _analyze_user_ip(self, progress, batch_size): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.runInteraction("user_ips_analyze", user_ips_analyze) + yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) - yield self._end_background_update("user_ips_analyze") + yield self.db.updates._end_background_update("user_ips_analyze") return 1 @@ -140,7 +141,7 @@ def get_last_seen(txn): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.runInteraction( + end_last_seen = yield self.db.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -271,14 +272,14 @@ def remove(txn): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.runInteraction("user_ips_dups_remove", remove) + yield self.db.runInteraction("user_ips_dups_remove", remove) if last: - yield self._end_background_update("user_ips_remove_dupes") + yield self.db.updates._end_background_update("user_ips_remove_dupes") return batch_size @@ -344,7 +345,7 @@ def _devices_last_seen_update_txn(txn): txn.execute_batch(sql, rows) _, _, _, user_id, device_id = rows[-1] - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "devices_last_seen", {"last_user_id": user_id, "last_device_id": device_id}, @@ -352,24 +353,24 @@ def _devices_last_seen_update_txn(txn): return len(rows) - updated = yield self.runInteraction( + updated = yield self.db.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) if not updated: - yield self._end_background_update("devices_last_seen") + yield self.db.updates._end_background_update("devices_last_seen") return updated class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) - super(ClientIpStore, self).__init__(db_conn, hs) + super(ClientIpStore, self).__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age @@ -417,12 +418,12 @@ def _update_client_ips_batch(self): to_update = self._batch_row_update self._batch_row_update = {} - return self.runInteraction( + return self.db.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self._unsafe_to_upsert_tables or ( + if "user_ips" in self.db._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): self.database_engine.lock_table(txn, "user_ips") @@ -431,7 +432,7 @@ def _update_client_ips_batch_txn(self, txn, to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -450,16 +451,18 @@ def _update_client_ips_batch_txn(self, txn, to_update): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self._simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue @@ -483,7 +486,7 @@ def get_last_client_ip_by_device(self, user_id, device_id): if device_id is not None: keyvalues["device_id"] = device_id - res = yield self._simple_select_list( + res = yield self.db.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -516,7 +519,7 @@ def get_user_ip_and_agents(self, user): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], @@ -546,7 +549,9 @@ async def _prune_old_user_ips(self): # Nothing to do return - if not await self.has_completed_background_update("devices_last_seen"): + if not await self.db.updates.has_completed_background_update( + "devices_last_seen" + ): # Only start pruning if we have finished populating the devices # last seen info. return @@ -577,4 +582,4 @@ async def _prune_old_user_ips(self): def _prune_old_user_ips_txn(txn): txn.execute(sql, (timestamp,)) - await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) + await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index a23744f11c10..85cfa1685058 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,7 +21,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ def get_new_messages_for_device_txn(txn): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @@ -109,7 +109,7 @@ def delete_messages_for_device_txn(txn): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.runInteraction( + count = yield self.db.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -178,7 +178,7 @@ def get_new_messages_for_remote_destination_txn(txn): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -203,25 +203,25 @@ def delete_messages_for_remote_destination_txn(txn): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.runInteraction( + return self.db.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) -class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) @@ -232,9 +232,9 @@ def reindex_txn(conn): txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 @@ -242,8 +242,8 @@ def reindex_txn(conn): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. @@ -294,7 +294,7 @@ def add_messages_txn(txn, now_ms, stream_id): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -314,7 +314,7 @@ def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self._simple_select_one_txn( + already_inserted = self.db.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -326,7 +326,7 @@ def add_messages_txn(txn, now_ms, stream_id): # Add an entry for this message_id so that we know we've processed # it. - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -344,7 +344,7 @@ def add_messages_txn(txn, now_ms, stream_id): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, @@ -465,6 +465,6 @@ def get_all_new_device_messages_txn(txn): return rows - return self.runInteraction( + return self.db.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 71f62036c095..9a828231c46e 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -30,16 +30,16 @@ whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - Cache, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import ( + Cache, + cached, + cachedInlineCallbacks, + cachedList, +) logger = logging.getLogger(__name__) @@ -61,7 +61,7 @@ def get_device(self, user_id, device_id): Raises: StoreError: if the device is not found """ - return self._simple_select_one( + return self.db.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -80,7 +80,7 @@ def get_devices_by_user(self, user_id): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self._simple_select_list( + devices = yield self.db.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -122,7 +122,7 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit): # consider the device update to be too large, and simply skip the # stream_id; the rationale being that such a large device list update # is likely an error. - updates = yield self.runInteraction( + updates = yield self.db.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -283,7 +283,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m """ devices = ( - yield self.runInteraction( + yield self.db.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -340,12 +340,12 @@ def f(txn): rows = txn.fetchall() return rows[0][0] - return self.runInteraction("get_last_device_update_for_remote_user", f) + return self.db.runInteraction("get_last_device_update_for_remote_user", f) def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. """ - return self.runInteraction( + return self.db.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -399,7 +399,7 @@ def add_user_signature_change_to_streams(self, from_user_id, user_ids): """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -414,7 +414,7 @@ def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id) from_user_id, stream_id, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "user_signature_stream", values={ @@ -466,7 +466,7 @@ def get_user_devices_from_cache(self, query_list): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( + content = yield self.db.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -476,7 +476,7 @@ def _get_cached_user_device(self, user_id, device_id): @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( + devices = yield self.db.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -492,7 +492,7 @@ def get_devices_with_keys_by_user(self, user_id): Returns: (stream_id, devices) """ - return self.runInteraction( + return self.db.runInteraction( "get_devices_with_keys_by_user", self._get_devices_with_keys_by_user_txn, user_id, @@ -565,7 +565,7 @@ def _get_users_whose_devices_changed_txn(txn): return changes - return self.runInteraction( + return self.db.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) @@ -584,7 +584,7 @@ def get_users_whose_signatures_changed(self, user_id, from_key): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self._execute( + rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return set(user for row in rows for user in json.loads(row[0])) @@ -605,7 +605,7 @@ def get_all_device_list_changes_for_remotes(self, from_key, to_key): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination """ - return self._execute( + return self.db.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -614,7 +614,7 @@ def get_device_list_last_stream_id_for_remote(self, user_id): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,7 +628,7 @@ def get_device_list_last_stream_id_for_remote(self, user_id): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -642,11 +642,11 @@ def get_device_list_last_stream_id_for_remotes(self, user_ids): return results -class DeviceBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) +class DeviceBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -654,7 +654,7 @@ def __init__(self, db_conn, hs): ) # create a unique index on device_lists_remote_cache - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -663,7 +663,7 @@ def __init__(self, db_conn, hs): ) # And one on device_lists_remote_extremeties - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -672,7 +672,7 @@ def __init__(self, db_conn, hs): ) # once they complete, we can remove the old non-unique indexes. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) @@ -685,14 +685,16 @@ def f(conn): txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + yield self.db.runWithConnection(f) + yield self.db.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) return 1 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. @@ -722,7 +724,7 @@ def store_device(self, user_id, device_id, initial_device_display_name): return False try: - inserted = yield self._simple_insert( + inserted = yield self.db.simple_insert( "devices", values={ "user_id": user_id, @@ -736,7 +738,7 @@ def store_device(self, user_id, device_id, initial_device_display_name): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self._simple_select_one_onecol( + hidden = yield self.db.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -771,7 +773,7 @@ def delete_device(self, user_id, device_id): Returns: defer.Deferred """ - yield self._simple_delete_one( + yield self.db.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -789,7 +791,7 @@ def delete_devices(self, user_id, device_ids): Returns: defer.Deferred """ - yield self._simple_delete_many( + yield self.db.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -818,7 +820,7 @@ def update_device(self, user_id, device_id, new_display_name=None): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self._simple_update_one( + return self.db.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -829,7 +831,7 @@ def update_device(self, user_id, device_id, new_display_name=None): def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - yield self._simple_delete( + yield self.db.simple_delete( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", @@ -853,7 +855,7 @@ def update_remote_device_list_cache_entry( Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -866,7 +868,7 @@ def _update_remote_device_list_cache_entry_txn( self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -874,7 +876,7 @@ def _update_remote_device_list_cache_entry_txn( txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -890,7 +892,7 @@ def _update_remote_device_list_cache_entry_txn( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -914,7 +916,7 @@ def update_remote_device_list_cache(self, user_id, devices, stream_id): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -923,11 +925,11 @@ def update_remote_device_list_cache(self, user_id, devices, stream_id): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -946,7 +948,7 @@ def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id) self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -962,7 +964,7 @@ def add_device_change_to_streams(self, user_id, device_ids, hosts): (if any) should be poked. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, user_id, @@ -995,7 +997,7 @@ def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): [(user_id, device_id, stream_id) for device_id in device_ids], ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1006,7 +1008,7 @@ def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): context = get_active_span_text_map() - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ @@ -1069,7 +1071,7 @@ def _prune_txn(txn): return run_as_background_process( "prune_old_outbound_device_pokes", - self.runInteraction, + self.db.runInteraction, "_prune_old_outbound_device_pokes", _prune_txn, ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index 297966d9f43c..c9e7de7d1248 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -36,7 +36,7 @@ def get_association_from_room_alias(self, room_alias): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self._simple_select_one_onecol( + room_id = yield self.db.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -47,7 +47,7 @@ def get_association_from_room_alias(self, room_alias): if not room_id: return None - servers = yield self._simple_select_onecol( + servers = yield self.db.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,7 +60,7 @@ def get_association_from_room_alias(self, room_alias): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,7 +69,7 @@ def get_room_alias_creator(self, room_alias): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -93,7 +93,7 @@ def create_room_alias_association(self, room_alias, room_id, servers, creator=No """ def alias_txn(txn): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "room_aliases", { @@ -103,7 +103,7 @@ def alias_txn(txn): }, ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_alias_servers", values=[ @@ -117,7 +117,9 @@ def alias_txn(txn): ) try: - ret = yield self.runInteraction("create_room_alias_association", alias_txn) + ret = yield self.db.runInteraction( + "create_room_alias_association", alias_txn + ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() @@ -126,7 +128,7 @@ def alias_txn(txn): @defer.inlineCallbacks def delete_room_alias(self, room_alias): - room_id = yield self.runInteraction( + room_id = yield self.db.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) @@ -168,6 +170,6 @@ def _update_aliases_for_room_txn(txn): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.runInteraction( + return self.db.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 113224fd7ca3..84594cf0a9bc 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -38,7 +38,7 @@ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): StoreError """ - yield self._simple_update_one( + yield self.db.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -89,7 +89,7 @@ def add_e2e_room_keys(self, user_id, version, room_keys): } ) - yield self._simple_insert_many( + yield self.db.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @@ -125,7 +125,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): if session_id: keyvalues["session_id"] = session_id - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -170,7 +170,7 @@ def get_e2e_room_keys_multi(self, user_id, version, room_keys): Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key """ - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -234,7 +234,7 @@ def count_e2e_room_keys(self, user_id, version): version (str): the version ID of the backup we're querying about """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", @@ -267,7 +267,7 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): if session_id: keyvalues["session_id"] = session_id - yield self._simple_delete( + yield self.db.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -312,7 +312,7 @@ def _get_e2e_room_keys_version_info_txn(txn): # it isn't there. raise StoreError(404, "No row found") - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, @@ -324,7 +324,7 @@ def _get_e2e_room_keys_version_info_txn(txn): result["etag"] = 0 return result - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @@ -352,7 +352,7 @@ def _create_e2e_room_keys_version_txn(txn): new_version = str(int(current_version) + 1) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -365,7 +365,7 @@ def _create_e2e_room_keys_version_txn(txn): return new_version - return self.runInteraction( + return self.db.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -391,7 +391,7 @@ def update_e2e_room_keys_version( updatevalues["etag"] = version_etag if updatevalues: - return self._simple_update( + return self.db.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, @@ -420,19 +420,19 @@ def _delete_e2e_room_keys_version_txn(txn): else: this_version = version - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self._simple_update_one_txn( + return self.db.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index d8ad59ad9369..38cd0ca9b821 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -48,7 +48,7 @@ def get_e2e_device_keys( if not query_list: return {} - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -125,7 +125,7 @@ def _get_e2e_device_keys_txn( ) txn.execute(sql, query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) result = {} for row in rows: @@ -143,15 +143,30 @@ def _get_e2e_device_keys_txn( ) txn.execute(signature_sql, signature_query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) + # add each cross-signing signature to the correct device in the result dict. for row in rows: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] target_user_id = row["target_user_id"] target_device_id = row["target_device_id"] - if target_user_id in result and target_device_id in result[target_user_id]: - result[target_user_id][target_device_id].setdefault( - "signatures", {} - ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"] + signature = row["signature"] + + target_user_result = result.get(target_user_id) + if not target_user_result: + continue + + target_device_result = target_user_result.get(target_device_id) + if not target_device_result: + # note that target_device_result will be None for deleted devices. + continue + + target_device_signatures = target_device_result.setdefault("signatures", {}) + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature log_kv(result) return result @@ -171,7 +186,7 @@ def get_e2e_one_time_keys(self, user_id, device_id, key_ids): key_id) to json string for key """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -204,7 +219,7 @@ def _add_e2e_one_time_keys(txn): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -223,7 +238,7 @@ def _add_e2e_one_time_keys(txn): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.runInteraction( + yield self.db.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -246,7 +261,9 @@ def _count_e2e_one_time_keys(txn): result[algorithm] = key_count return result - return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + return self.db.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): """Returns a user's cross-signing key. @@ -307,7 +324,7 @@ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): Returns: dict of the key data or None if not found """ - return self.runInteraction( + return self.db.runInteraction( "get_e2e_cross_signing_key", self._get_e2e_cross_signing_key_txn, user_id, @@ -335,7 +352,7 @@ def get_all_user_signature_changes_for_remotes(self, from_key, to_key): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id """ - return self._execute( + return self.db.execute( "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key ) @@ -352,7 +369,7 @@ def _set_e2e_device_keys_txn(txn): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self._simple_select_one_onecol_txn( + old_key_json = self.db.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -368,7 +385,7 @@ def _set_e2e_device_keys_txn(txn): log_kv({"Message": "Device key already stored."}) return False - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -377,7 +394,7 @@ def _set_e2e_device_keys_txn(txn): log_kv({"message": "Device keys stored."}) return True - return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" @@ -416,7 +433,9 @@ def _claim_e2e_one_time_keys(txn): ) return result - return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) + return self.db.runInteraction( + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys + ) def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device_txn(txn): @@ -427,12 +446,12 @@ def delete_e2e_keys_by_device_txn(txn): "user_id": user_id, } ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -441,7 +460,7 @@ def delete_e2e_keys_by_device_txn(txn): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) @@ -477,7 +496,7 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "devices", values={ @@ -490,7 +509,7 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ @@ -509,7 +528,7 @@ def set_e2e_cross_signing_key(self, user_id, key_type, key): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.runInteraction( + return self.db.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, user_id, @@ -524,7 +543,7 @@ def store_e2e_cross_signing_signatures(self, user_id, signatures): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self._simple_insert_many( + return self.db.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 90bef0cd2c9a..1f517e8fad7d 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -28,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -58,7 +59,7 @@ def get_auth_chain_ids(self, event_ids, include_given=False): Returns: list of event_ids """ - return self.runInteraction( + return self.db.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given ) @@ -90,12 +91,12 @@ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): return list(results) def get_oldest_events_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id ) def get_oldest_events_with_depth_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -126,7 +127,7 @@ def get_max_depth_of(self, event_ids): Returns Deferred[int] """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -140,7 +141,7 @@ def get_max_depth_of(self, event_ids): return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self._simple_select_onecol_txn( + return self.db.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -188,7 +189,7 @@ def get_latest_event_ids_and_hashes_in_room(self, room_id): where *hashes* is a map from algorithm to hash. """ - return self.runInteraction( + return self.db.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, room_id, @@ -229,13 +230,13 @@ def _get_rooms_with_many_extremities_txn(txn): txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -266,12 +267,12 @@ def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ - return self.runInteraction( + return self.db.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self._simple_select_one_onecol_txn( + min_depth = self.db.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -337,7 +338,7 @@ def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) @@ -352,7 +353,7 @@ def get_backfill_events(self, room_id, event_list, limit): limit (int) """ return ( - self.runInteraction( + self.db.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, @@ -383,7 +384,7 @@ def _get_backfill_events(self, txn, room_id, event_list, limit): queue = PriorityQueue() for event_id in event_list: - depth = self._simple_select_one_onecol_txn( + depth = self.db.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -415,7 +416,7 @@ def _get_backfill_events(self, txn, room_id, event_list, limit): @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.runInteraction( + ids = yield self.db.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -468,7 +469,7 @@ def get_successor_events(self, event_ids): Returns: Deferred[list[str]] """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -491,10 +492,10 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -508,7 +509,7 @@ def _update_min_depth_for_room_txn(self, txn, room_id, depth): if min_depth and depth >= min_depth: return - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -520,7 +521,7 @@ def _handle_mult_prev_events(self, txn, events): For the given event, update the event edges table and forward and backward extremities tables. """ - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_edges", values=[ @@ -604,13 +605,13 @@ def _delete_old_forward_extrem_cache_txn(txn): return run_as_background_process( "delete_old_forward_extrem_cache", - self.runInteraction, + self.db.runInteraction, "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, ) def clean_room_for_join(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) @@ -654,17 +655,17 @@ def delete_event_auth(txn): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) return min_stream_id >= target_min_stream_id - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 04ce21ac66b6..9988a6d3fca4 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -24,6 +24,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -93,7 +94,7 @@ def __init__(self, db_conn, hs): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, @@ -177,7 +178,7 @@ def f(txn): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.runInteraction("get_push_action_users_in_range", f) + ret = yield self.db.runInteraction("get_push_action_users_in_range", f) return ret @defer.inlineCallbacks @@ -229,7 +230,7 @@ def get_after_receipt(txn): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -257,7 +258,7 @@ def get_no_receipt(txn): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -329,7 +330,7 @@ def get_after_receipt(txn): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -357,7 +358,7 @@ def get_no_receipt(txn): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -407,7 +408,7 @@ def _get_if_maybe_push_in_range_for_user_txn(txn): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) @@ -441,7 +442,7 @@ def _gen_entry(user_id, actions): ) def _add_push_actions_to_staging_txn(txn): - # We don't use _simple_insert_many here to avoid the overhead + # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. sql = """ @@ -458,7 +459,7 @@ def _add_push_actions_to_staging_txn(txn): ), ) - return self.runInteraction( + return self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) @@ -472,7 +473,7 @@ def remove_push_actions_from_staging(self, event_id): """ try: - res = yield self._simple_delete( + res = yield self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -489,7 +490,7 @@ def remove_push_actions_from_staging(self, event_id): def _find_stream_orderings_for_times(self): return run_as_background_process( "event_push_action_stream_orderings", - self.runInteraction, + self.db.runInteraction, "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) @@ -525,7 +526,7 @@ def find_first_stream_ordering_after_ts(self, ts): Deferred[int]: stream ordering of the first event received on/after the timestamp """ - return self.runInteraction( + return self.db.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -611,17 +612,17 @@ def _find_first_stream_ordering_after_ts_txn(txn, ts): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", @@ -677,7 +678,7 @@ def _set_push_actions_for_event_and_users_txn( ) for event, _ in events_and_contexts: - user_ids = self._simple_select_onecol_txn( + user_ids = self.db.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -727,9 +728,9 @@ def f(txn): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - push_actions = yield self.runInteraction("get_push_actions_for_user", f) + push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions @@ -748,7 +749,7 @@ def f(txn): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.runInteraction("get_time_of_last_push_action_before", f) + result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @defer.inlineCallbacks @@ -757,7 +758,9 @@ def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) + result = yield self.db.runInteraction( + "get_latest_push_action_stream_ordering", f + ) return result[0] or 0 def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): @@ -830,7 +833,7 @@ def _rotate_notifs(self): while True: logger.info("Rotating notifications") - caught_up = yield self.runInteraction( + caught_up = yield self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: @@ -844,7 +847,7 @@ def _rotate_notifs_txn(self, txn): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -880,7 +883,7 @@ def _rotate_notifs_txn(self, txn): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -912,7 +915,7 @@ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 2737a1d3aeca..998bba1aad62 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -38,10 +38,10 @@ from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.database import Database from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -94,13 +94,10 @@ def f(self, *args, **kwargs): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) class EventsStore( - StateGroupWorkerStore, - EventFederationStore, - EventsWorkerStore, - BackgroundUpdateStore, + StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsStore, self).__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count @@ -130,6 +127,8 @@ def _censor_redactions(): if self.hs.config.redaction_retention_period is not None: hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -141,7 +140,7 @@ def fetch(txn): ) return txn.fetchall() - res = yield self.runInteraction("read_forward_extremities", fetch) + res = yield self.db.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) @_retry_on_integrity_error @@ -206,7 +205,7 @@ def _persist_events_and_state_updates( for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.runInteraction( + yield self.db.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, @@ -279,7 +278,7 @@ def _get_events_which_are_prevs_txn(txn, batch): results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) @@ -343,7 +342,7 @@ def _get_prevs_before_rejected_txn(txn, batch): existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) @@ -430,7 +429,7 @@ def _persist_events_txn( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -578,12 +577,12 @@ def _update_forward_extremities_txn( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in iteritems(new_forward_extremities): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -596,7 +595,7 @@ def _update_forward_extremities_txn( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -720,7 +719,7 @@ def _update_outliers_txn(self, txn, events_and_contexts): # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -792,7 +791,7 @@ def event_dict(event): d.pop("redacted_because", None) return d - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_json", values=[ @@ -809,7 +808,7 @@ def event_dict(event): ], ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="events", values=[ @@ -839,7 +838,7 @@ def event_dict(event): # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self._simple_update_txn( + self.db.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -940,6 +939,12 @@ def _update_metadata_tables_txn( txn, event.event_id, labels, event.room_id, event.depth ) + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -975,7 +980,7 @@ def _update_metadata_tables_txn( state_values.append(vals) - self._simple_insert_many_txn(txn, table="state_events", values=state_values) + self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1006,7 +1011,7 @@ def _add_to_cache(self, txn, events_and_contexts): ) txn.execute(sql + clause, args) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: @@ -1024,7 +1029,7 @@ def _store_redaction(self, txn, event): # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="redactions", values={ @@ -1034,20 +1039,25 @@ def _store_redaction(self, txn, event): }, ) - @defer.inlineCallbacks - def _censor_redactions(self): + async def _censor_redactions(self): """Censors all redactions older than the configured period that haven't been censored yet. By censor we mean update the event_json table with the redacted event. - - Returns: - Deferred """ if self.hs.config.redaction_retention_period is None: return + if not ( + await self.db.updates.has_completed_background_update( + "redactions_have_censored_ts_idx" + ) + ): + # We don't want to run this until the appropriate index has been + # created. + return + before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period # We fetch all redactions that: @@ -1069,15 +1079,15 @@ def _censor_redactions(self): LIMIT ? """ - rows = yield self._execute( + rows = await self.db.execute( "_censor_redactions_fetch", None, sql, before_ts, 100 ) updates = [] for redaction_id, event_id in rows: - redaction_event = yield self.get_event(redaction_id, allow_none=True) - original_event = yield self.get_event( + redaction_event = await self.get_event(redaction_id, allow_none=True) + original_event = await self.get_event( event_id, allow_rejected=True, allow_none=True ) @@ -1101,21 +1111,32 @@ def _censor_redactions(self): def _update_censor_txn(txn): for redaction_id, event_id, pruned_json in updates: if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) + self._censor_event_txn(txn, event_id, pruned_json) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, updatevalues={"have_censored": True}, ) - yield self.runInteraction("_update_censor_txn", _update_censor_txn) + await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self.db.simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) @defer.inlineCallbacks def count_daily_messages(self): @@ -1136,7 +1157,7 @@ def _count_messages(txn): (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_messages", _count_messages) + ret = yield self.db.runInteraction("count_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1157,7 +1178,7 @@ def _count_messages(txn): (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) + ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1172,7 +1193,7 @@ def _count(txn): (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_active_rooms", _count) + ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret def get_current_backfill_token(self): @@ -1224,7 +1245,7 @@ def get_all_new_forward_event_rows(txn): return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) @@ -1269,7 +1290,7 @@ def get_all_new_backfill_event_rows(txn): return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @@ -1362,7 +1383,7 @@ def get_all_new_events_txn(txn): backward_ex_outliers, ) - return self.runInteraction("get_all_new_events", get_all_new_events_txn) + return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) def purge_history(self, room_id, token, delete_local_events): """Deletes room history before a certain point @@ -1382,7 +1403,7 @@ def purge_history(self, room_id, token, delete_local_events): deleted events. """ - return self.runInteraction( + return self.db.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -1630,7 +1651,7 @@ def purge_room(self, room_id): Deferred[List[int]]: The list of state groups to delete. """ - return self.runInteraction("purge_room", self._purge_room_txn, room_id) + return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before @@ -1749,7 +1770,7 @@ def purge_unreferenced_state_groups( to delete. """ - return self.runInteraction( + return self.db.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -1761,7 +1782,7 @@ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete) "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -1788,15 +1809,15 @@ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete) curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1833,7 +1854,7 @@ def get_previous_state_groups(self, state_groups): state group. """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -1852,7 +1873,7 @@ def purge_room_state(self, room_id, state_groups_to_delete): state_groups_to_delete (list[int]): State groups to delete """ - return self.runInteraction( + return self.db.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, @@ -1863,7 +1884,7 @@ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -1874,7 +1895,7 @@ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -1885,7 +1906,7 @@ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups", column="id", @@ -1902,7 +1923,7 @@ async def is_event_after(self, event_id1, event_id2): @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): - res = yield self._simple_select_one( + res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1925,7 +1946,7 @@ def get_all_updated_current_state_deltas_txn(txn): txn.execute(sql, (from_token, to_token, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1943,7 +1964,7 @@ def insert_labels_for_event_txn( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self._simple_insert_many_txn( + return self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1957,6 +1978,101 @@ def insert_labels_for_event_txn( ], ) + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self.db.simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json(prune_event_dict(event.get_dict())) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self.db.simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.db.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index aa87f9abc538..5177b7101600 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -22,30 +22,30 @@ from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database logger = logging.getLogger(__name__) -class EventsBackgroundUpdatesStore(BackgroundUpdateStore): +class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -56,7 +56,7 @@ def __init__(self, db_conn, hs): # an event_id index on event_search is useful for the purge_history # api. Plus it means we get to enforce some integrity with a UNIQUE # clause - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -65,16 +65,16 @@ def __init__(self, db_conn, hs): psql_only=True, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "redactions_received_ts", self._redactions_received_ts ) # This index gets deleted in `event_fix_redactions_bytes` update - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_fix_redactions_bytes_create_index", index_name="redactions_censored_redacts", table="redactions", @@ -82,14 +82,22 @@ def __init__(self, db_conn, hs): where_clause="have_censored", ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_store_labels", self._event_store_labels ) + self.db.updates.register_background_index_update( + "redactions_have_censored_ts_idx", + index_name="redactions_have_censored_ts", + table="redactions", + columns=["received_ts"], + where_clause="NOT have_censored", + ) + @defer.inlineCallbacks def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] @@ -145,18 +153,20 @@ def reindex_txn(txn): "rows_inserted": rows_inserted + len(rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + ) return result @@ -189,7 +199,7 @@ def reindex_search_txn(txn): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self._simple_select_many_txn( + ev_rows = self.db.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -222,18 +232,20 @@ def reindex_search_txn(txn): "rows_inserted": rows_inserted + len(rows_to_update), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + yield self.db.updates._end_background_update( + self.EVENT_ORIGIN_SERVER_TS_NAME + ) return result @@ -366,7 +378,7 @@ def _cleanup_extremities_bg_update_txn(txn): to_delete.intersection_update(original_set) - deleted = self._simple_delete_many_txn( + deleted = self.db.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -382,7 +394,7 @@ def _cleanup_extremities_bg_update_txn(txn): if deleted: # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="events", column="event_id", @@ -396,7 +408,7 @@ def _cleanup_extremities_bg_update_txn(txn): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -406,17 +418,19 @@ def _cleanup_extremities_bg_update_txn(txn): return len(original_set) - num_handled = yield self.runInteraction( + num_handled = yield self.db.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: - yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + yield self.db.updates._end_background_update( + self.DELETE_SOFT_FAILED_EXTREMITIES + ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.runInteraction( + yield self.db.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) @@ -464,18 +478,18 @@ def _redactions_received_ts_txn(txn): txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} ) return len(rows) - count = yield self.runInteraction( + count = yield self.db.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) if not count: - yield self._end_background_update("redactions_received_ts") + yield self.db.updates._end_background_update("redactions_received_ts") return count @@ -501,11 +515,11 @@ def _event_fix_redactions_bytes_txn(txn): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.runInteraction( + yield self.db.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self._end_background_update("event_fix_redactions_bytes") + yield self.db.updates._end_background_update("event_fix_redactions_bytes") return 1 @@ -533,7 +547,7 @@ def _event_store_labels_txn(txn): try: event_json = json.loads(event_json_raw) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -559,17 +573,17 @@ def _event_store_labels_txn(txn): nbrows += 1 last_row_event_id = event_id - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "event_store_labels", {"last_event_id": last_row_event_id} ) return nbrows - num_rows = yield self.runInteraction( + num_rows = yield self.db.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) if not num_rows: - yield self._end_background_update("event_store_labels") + yield self.db.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index de862f4141c7..2c9142814cbe 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -17,6 +17,7 @@ import itertools import logging +import threading from collections import namedtuple from typing import List, Optional @@ -34,8 +35,10 @@ from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_domain_from_id from synapse.util import batch_iter +from synapse.util.caches.descriptors import Cache from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -65,6 +68,17 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) + + self._get_event_cache = Cache( + "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + ) + + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + def get_received_ts(self, event_id): """Get received_ts (when it was persisted) for the event. @@ -77,7 +91,7 @@ def get_received_ts(self, event_id): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -116,7 +130,7 @@ def _get_approximate_received_ts_txn(txn): return ts - return self.runInteraction( + return self.db.runInteraction( "get_approximate_received_ts", _get_approximate_received_ts_txn ) @@ -462,7 +476,7 @@ def _fetch_event_list(self, conn, event_list): event_id for events, _ in event_list for event_id in events ) - row_dict = self._new_transaction( + row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -594,7 +608,7 @@ def _enqueue_events(self, events): if should_start: run_as_background_process( - "fetch_events", self.runWithConnection, self._do_fetch + "fetch_events", self.db.runWithConnection, self._do_fetch ) logger.debug("Loading %d events: %s", len(events), events) @@ -755,7 +769,7 @@ def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -790,42 +804,10 @@ def have_seen_events_txn(txn, chunk): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) - return results - - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" + yield self.db.runInteraction( + "have_seen_events", have_seen_events_txn, chunk ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_seen_events_with_rejections", f) + return results def _get_total_state_event_counts_txn(self, txn, room_id): """ @@ -851,7 +833,7 @@ def get_total_state_event_counts(self, room_id): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_total_state_event_counts", self._get_total_state_event_counts_txn, room_id, @@ -876,7 +858,7 @@ def get_current_state_event_counts(self, room_id): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py index f05ace299afe..342d6622a458 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -30,7 +30,7 @@ def get_user_filter(self, user_localpart, filter_id): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self._simple_select_one_onecol( + def_json = yield self.db.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -71,4 +71,4 @@ def _do_txn(txn): return filter_id - return self.runInteraction("add_user_filter", _do_txn) + return self.db.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 5ded539af80d..6acd45e9f36e 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -35,7 +35,7 @@ def set_group_join_policy(self, group_id, join_policy): * "invite" * "open" """ - return self._simple_update_one( + return self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -43,7 +43,7 @@ def set_group_join_policy(self, group_id, join_policy): ) def get_group(self, group_id): - return self._simple_select_one( + return self.db.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -65,7 +65,7 @@ def get_users_in_group(self, group_id, include_private=False): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.db.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -75,7 +75,7 @@ def get_users_in_group(self, group_id, include_private=False): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -89,7 +89,7 @@ def get_rooms_in_group(self, group_id, include_private=False): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.db.simple_select_list( table="group_rooms", keyvalues=keyvalues, retcols=("room_id", "is_public"), @@ -153,10 +153,12 @@ def _get_rooms_for_summary_txn(txn): return rooms, categories - return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) + return self.db.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -180,7 +182,7 @@ def _add_room_to_summary_txn( an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self._simple_select_one_onecol_txn( + room_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -193,7 +195,7 @@ def _add_room_to_summary_txn( if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -204,7 +206,7 @@ def _add_room_to_summary_txn( raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -224,7 +226,7 @@ def _add_room_to_summary_txn( (group_id, category_id, group_id, category_id), ) - existing = self._simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -257,7 +259,7 @@ def _add_room_to_summary_txn( to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -271,7 +273,7 @@ def _add_room_to_summary_txn( if is_public is None: is_public = True - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -287,7 +289,7 @@ def remove_room_from_summary(self, group_id, room_id, category_id): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self._simple_delete( + return self.db.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -299,7 +301,7 @@ def remove_room_from_summary(self, group_id, room_id, category_id): @defer.inlineCallbacks def get_group_categories(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -316,7 +318,7 @@ def get_group_categories(self, group_id): @defer.inlineCallbacks def get_group_category(self, group_id, category_id): - category = yield self._simple_select_one( + category = yield self.db.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -343,7 +345,7 @@ def upsert_group_category(self, group_id, category_id, profile, is_public): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.db.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -352,7 +354,7 @@ def upsert_group_category(self, group_id, category_id, profile, is_public): ) def remove_group_category(self, group_id, category_id): - return self._simple_delete( + return self.db.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -360,7 +362,7 @@ def remove_group_category(self, group_id, category_id): @defer.inlineCallbacks def get_group_roles(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -377,7 +379,7 @@ def get_group_roles(self, group_id): @defer.inlineCallbacks def get_group_role(self, group_id, role_id): - role = yield self._simple_select_one( + role = yield self.db.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -404,7 +406,7 @@ def upsert_group_role(self, group_id, role_id, profile, is_public): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.db.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -413,14 +415,14 @@ def upsert_group_role(self, group_id, role_id, profile, is_public): ) def remove_group_role(self, group_id, role_id): - return self._simple_delete( + return self.db.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -444,7 +446,7 @@ def _add_user_to_summary_txn( an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self._simple_select_one_onecol_txn( + user_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -457,7 +459,7 @@ def _add_user_to_summary_txn( if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -468,7 +470,7 @@ def _add_user_to_summary_txn( raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -488,7 +490,7 @@ def _add_user_to_summary_txn( (group_id, role_id, group_id, role_id), ) - existing = self._simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -517,7 +519,7 @@ def _add_user_to_summary_txn( to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -531,7 +533,7 @@ def _add_user_to_summary_txn( if is_public is None: is_public = True - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_users", values={ @@ -547,7 +549,7 @@ def remove_user_from_summary(self, group_id, user_id, role_id): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self._simple_delete( + return self.db.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -561,7 +563,7 @@ def get_local_groups_for_room(self, room_id): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -625,12 +627,12 @@ def _get_users_for_summary_txn(txn): return users, roles - return self.runInteraction( + return self.db.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) def is_user_in_group(self, user_id, group_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -639,7 +641,7 @@ def is_user_in_group(self, user_id, group_id): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -650,7 +652,7 @@ def is_user_admin_in_group(self, group_id, user_id): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self._simple_insert( + return self.db.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -659,7 +661,7 @@ def add_group_invite(self, group_id, user_id): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -682,7 +684,7 @@ def get_users_membership_info_in_group(self, group_id, user_id): """ def _get_users_membership_in_group_txn(txn): - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -697,7 +699,7 @@ def _get_users_membership_in_group_txn(txn): "is_privileged": row["is_admin"], } - row = self._simple_select_one_onecol_txn( + row = self.db.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -710,7 +712,7 @@ def _get_users_membership_in_group_txn(txn): return {} - return self.runInteraction( + return self.db.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -738,7 +740,7 @@ def add_user_to_group( """ def _add_user_to_group_txn(txn): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_users", values={ @@ -749,14 +751,14 @@ def _add_user_to_group_txn(txn): }, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -766,7 +768,7 @@ def _add_user_to_group_txn(txn): }, ) if remote_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -777,49 +779,49 @@ def _add_user_to_group_txn(txn): }, ) - return self.runInteraction("add_user_to_group", _add_user_to_group_txn) + return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + return self.db.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self._simple_update( + return self.db.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -828,26 +830,26 @@ def update_room_in_group_visibility(self, group_id, room_id, is_public): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -857,7 +859,7 @@ def get_publicised_groups_for_user(self, user_id): def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self._simple_update_one( + return self.db.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -893,12 +895,12 @@ def register_user_group_membership( def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_membership", values={ @@ -911,7 +913,7 @@ def _register_user_group_membership_txn(txn, next_id): }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_updates", values={ @@ -930,7 +932,7 @@ def _register_user_group_membership_txn(txn, next_id): if membership == "join": if local_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -940,7 +942,7 @@ def _register_user_group_membership_txn(txn, next_id): }, ) if remote_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -951,12 +953,12 @@ def _register_user_group_membership_txn(txn, next_id): }, ) else: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -965,7 +967,7 @@ def _register_user_group_membership_txn(txn, next_id): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, @@ -976,7 +978,7 @@ def _register_user_group_membership_txn(txn, next_id): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self._simple_insert( + yield self.db.simple_insert( table="groups", values={ "group_id": group_id, @@ -991,7 +993,7 @@ def create_group( @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self._simple_update_one( + yield self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1008,16 +1010,16 @@ def _get_attestations_need_renewals_txn(txn): WHERE valid_until_ms <= ? """ txn.execute(sql, (valid_until_ms,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self._simple_update_one( + return self.db.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1027,7 +1029,7 @@ def update_attestation_renewal(self, group_id, user_id, attestation): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self._simple_update_one( + return self.db.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1046,7 +1048,7 @@ def remove_attestation_renewal(self, group_id, user_id): group_id (str) user_id (str) """ - return self._simple_delete( + return self.db.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1057,7 +1059,7 @@ def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self._simple_select_one( + row = yield self.db.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -1072,7 +1074,7 @@ def get_remote_attestation(self, group_id, user_id): return None def get_joined_groups(self, user_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -1099,7 +1101,7 @@ def _get_all_groups_for_user_txn(txn): for row in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -1109,7 +1111,7 @@ def get_groups_changes_for_user(self, user_id, from_token, to_token): user_id, from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_groups_changes_for_user_txn(txn): sql = """ @@ -1129,7 +1131,7 @@ def _get_groups_changes_for_user_txn(txn): for group_id, membership, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1139,7 +1141,7 @@ def get_all_groups_changes(self, from_token, to_token, limit): from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_all_groups_changes_txn(txn): sql = """ @@ -1154,7 +1156,7 @@ def _get_all_groups_changes_txn(txn): for stream_id, group_id, user_id, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_changes", _get_all_groups_changes_txn ) @@ -1188,8 +1190,8 @@ def _delete_group_txn(txn): ] for table in tables: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction("delete_group", _delete_group_txn) + return self.db.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ebc7db3ed693..6b12f5a75fe2 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -92,7 +92,7 @@ def _txn(txn): _get_keys(txn, batch) return keys - return self.runInteraction("get_server_verify_keys", _txn) + return self.db.runInteraction("get_server_verify_keys", _txn) def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): """Stores NACL verification keys for remote servers. @@ -127,9 +127,9 @@ def _invalidate(res): f((i,)) return res - return self.runInteraction( + return self.db.runInteraction( "store_server_verify_keys", - self._simple_upsert_many_txn, + self.db.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -157,7 +157,7 @@ def store_server_keys_json( ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self._simple_upsert( + return self.db.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -196,7 +196,7 @@ def _get_server_keys_json_txn(txn): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, @@ -211,4 +211,4 @@ def _get_server_keys_json_txn(txn): results[(server_name, key_id, from_server)] = rows return results - return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) + return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 0f2887bdcea7..80ca36dedfa9 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -12,14 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database -class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) - self.register_background_index_update( + self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", @@ -31,15 +34,15 @@ def __init__(self, db_conn, hs): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: None if the media_id doesn't exist. """ - return self._simple_select_one( + return self.db.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -64,7 +67,7 @@ def store_local_media( user_id, url_cache=None, ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository", { "media_id": media_id, @@ -124,12 +127,12 @@ def get_url_cache_txn(txn): ) ) - return self.runInteraction("get_url_cache", get_url_cache_txn) + return self.db.runInteraction("get_url_cache", get_url_cache_txn) def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -144,7 +147,7 @@ def store_url_cache( ) def get_local_media_thumbnails(self, media_id): - return self._simple_select_list( + return self.db.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -166,7 +169,7 @@ def store_local_thumbnail( thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -180,7 +183,7 @@ def store_local_thumbnail( ) def get_cached_remote_media(self, origin, media_id): - return self._simple_select_one( + return self.db.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -205,7 +208,7 @@ def store_cached_remote_media( upload_name, filesystem_id, ): - return self._simple_insert( + return self.db.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -250,10 +253,12 @@ def update_cache_txn(txn): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.runInteraction("update_cached_last_access_time", update_cache_txn) + return self.db.runInteraction( + "update_cached_last_access_time", update_cache_txn + ) def get_remote_media_thumbnails(self, origin, media_id): - return self._simple_select_list( + return self.db.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -278,7 +283,7 @@ def store_remote_media_thumbnail( thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.db.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -300,24 +305,24 @@ def get_remote_media_before(self, before_ts): " WHERE last_access_ts < ?" ) - return self._execute( - "get_remote_media_before", self.cursor_to_dict, sql, before_ts + return self.db.execute( + "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.runInteraction("delete_remote_media", delete_remote_media_txn) + return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) def get_expired_url_cache(self, now_ts): sql = ( @@ -331,7 +336,9 @@ def _get_expired_url_cache_txn(txn): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + return self.db.runInteraction( + "get_expired_url_cache", _get_expired_url_cache_txn + ) def delete_url_cache(self, media_ids): if len(media_ids) == 0: @@ -342,7 +349,7 @@ def delete_url_cache(self, media_ids): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( @@ -356,7 +363,7 @@ def _get_url_cache_media_before_txn(txn): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.runInteraction( + return self.db.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) @@ -373,6 +380,6 @@ def _delete_url_cache_media_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction( + return self.db.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index b41c3d317a49..27158534cbbb 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -27,13 +28,13 @@ class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number - self._new_transaction( - dbconn, + self.db.new_transaction( + db_conn, "initialise_mau_threepids", [], [], @@ -146,7 +147,7 @@ def _reap_users(txn, reserved_users): txn.execute(sql, query_args) reserved_users = yield self.get_registered_reserved_users() - yield self.runInteraction( + yield self.db.runInteraction( "reap_monthly_active_users", _reap_users, reserved_users ) # It seems poor to invalidate the whole cache, Postgres supports @@ -174,7 +175,7 @@ def _count_users(txn): (count,) = txn.fetchone() return count - return self.runInteraction("count_users", _count_users) + return self.db.runInteraction("count_users", _count_users) @defer.inlineCallbacks def get_registered_reserved_users(self): @@ -217,7 +218,7 @@ def upsert_monthly_active_user(self, user_id): if is_support: return - yield self.runInteraction( + yield self.db.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -261,7 +262,7 @@ def upsert_monthly_active_user_txn(self, txn, user_id): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self._simple_upsert_txn( + is_insert = self.db.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -281,7 +282,7 @@ def user_last_seen_monthly_active(self, user_id): """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py index 79b40044d923..cc21437e920e 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -3,7 +3,7 @@ class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self._simple_insert( + return self.db.simple_insert( table="open_id_tokens", values={ "token": token, @@ -28,4 +28,6 @@ def get_user_id_for_token_txn(txn): else: return rows[0][0] - return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) + return self.db.runInteraction( + "get_user_id_for_token", get_user_id_for_token_txn + ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 523ed6575eab..a2c83e08675d 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -29,7 +29,7 @@ def update_presence(self, presence_states): ) with stream_ordering_manager as stream_orderings: - yield self.runInteraction( + yield self.db.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, @@ -46,7 +46,7 @@ def _update_presence_txn(self, txn, stream_orderings, presence_states): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -88,7 +88,7 @@ def get_all_presence_updates_txn(txn): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) @@ -103,7 +103,7 @@ def _get_presence_for_user(self, user_id): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -129,7 +129,7 @@ def get_current_presence_token(self): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( + return self.db.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -140,7 +140,7 @@ def allow_presence_visible(self, observed_localpart, observer_userid): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( + return self.db.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index e4e8a1c1d6e6..2b52cf9c1a6d 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self._simple_select_one( + profile = yield self.db.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +42,7 @@ def get_profileinfo(self, user_localpart): ) def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +50,7 @@ def get_profile_displayname(self, user_localpart): ) def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +58,7 @@ def get_profile_avatar_url(self, user_localpart): ) def get_from_remote_profile_cache(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +67,12 @@ def get_from_remote_profile_cache(self, user_id): ) def create_profile(self, user_localpart): - return self._simple_insert( + return self.db.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +80,7 @@ def set_profile_displayname(self, user_localpart, new_displayname): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +95,7 @@ def add_remote_profile_cache(self, user_id, displayname, avatar_url): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self._simple_upsert( + return self.db.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +107,7 @@ def add_remote_profile_cache(self, user_id, displayname, avatar_url): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self._simple_update( + return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -125,7 +125,7 @@ def maybe_delete_remote_profile_cache(self, user_id): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self._simple_delete( + yield self.db.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -144,9 +144,9 @@ def _get_remote_profile_cache_entries_that_expire_txn(txn): txn.execute(sql, (last_checked,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) @@ -155,7 +155,7 @@ def _get_remote_profile_cache_entries_that_expire_txn(txn): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +166,7 @@ def is_subscribed_remote_profile_for_user(self, user_id): if res: return True - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b520062d8445..5ba13aa973a0 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -72,10 +73,10 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) - push_rules_prefill, push_rules_id = self._get_cache_dict( + push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -100,7 +101,7 @@ def get_max_push_rules_stream_id(self): @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -124,7 +125,7 @@ def get_push_rules_for_user(self, user_id): @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( + results = yield self.db.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -146,7 +147,7 @@ def have_push_rules_changed_txn(txn): (count,) = txn.fetchone() return bool(count) - return self.runInteraction( + return self.db.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) @@ -162,7 +163,7 @@ def bulk_get_push_rules(self, user_ids): results = {user_id: [] for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -320,7 +321,7 @@ def bulk_get_push_rules_enabled(self, user_ids): results = {user_id: {} for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -350,7 +351,7 @@ def add_push_rule( with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -364,7 +365,7 @@ def add_push_rule( after, ) else: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -395,7 +396,7 @@ def _add_push_rule_relative_txn( relative_to_rule = before or after - res = self._simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -499,7 +500,7 @@ def _upsert_push_rule_txn( actions_json, update_stream=True, ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id + """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -518,7 +519,7 @@ def _upsert_push_rule_txn( # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="push_rules", values={ @@ -561,7 +562,7 @@ def delete_push_rule(self, user_id, rule_id): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -571,7 +572,7 @@ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, @@ -582,7 +583,7 @@ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): def set_push_rule_enabled(self, user_id, rule_id, enabled): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -596,7 +597,7 @@ def _set_push_rule_enabled_txn( self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -636,7 +637,7 @@ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): update_stream=False, ) else: - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -655,7 +656,7 @@ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, @@ -675,7 +676,7 @@ def _insert_push_rules_update_txn( if data is not None: values.update(data) - self._simple_insert_txn(txn, "push_rules_stream", values=values) + self.db.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) @@ -699,7 +700,7 @@ def get_all_push_rule_updates_txn(txn): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index d76861cdc0b5..f07309ef093b 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -59,7 +59,7 @@ def _decode_pushers_rows(self, rows): @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self._simple_select_one_onecol( + ret = yield self.db.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -72,7 +72,7 @@ def get_pushers_by_user_id(self, user_id): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self._simple_select_list( + ret = yield self.db.simple_select_list( "pushers", keyvalues, [ @@ -100,11 +100,11 @@ def get_pushers_by(self, keyvalues): def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.runInteraction("get_all_pushers", get_pushers) + rows = yield self.db.runInteraction("get_all_pushers", get_pushers) return rows def get_all_updated_pushers(self, last_id, current_id, limit): @@ -134,7 +134,7 @@ def get_all_updated_pushers_txn(txn): return updated, deleted - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers", get_all_updated_pushers_txn ) @@ -177,7 +177,7 @@ def get_all_updated_pushers_rows_txn(txn): return results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) @@ -193,7 +193,7 @@ def get_if_user_has_pusher(self, user_id): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -229,8 +229,8 @@ def add_pusher( ): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so _simple_upsert will retry - yield self._simple_upsert( + # (app_id, pushkey, user_name) so simple_upsert will retry + yield self.db.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -255,7 +255,7 @@ def add_pusher( if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.runInteraction( + yield self.db.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, @@ -269,7 +269,7 @@ def delete_pusher_txn(txn, stream_id): txn, self.get_if_user_has_pusher, (user_id,) ) - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -278,7 +278,7 @@ def delete_pusher_txn(txn, stream_id): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -290,13 +290,13 @@ def delete_pusher_txn(txn, stream_id): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) + yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) @defer.inlineCallbacks def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self._simple_update_one( + yield self.db.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -319,7 +319,7 @@ def update_pusher_last_stream_ordering_and_success( Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self._simple_update( + updated = yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -333,7 +333,7 @@ def update_pusher_last_stream_ordering_and_success( @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self._simple_update( + yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -342,7 +342,7 @@ def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self._simple_select_list( + res = yield self.db.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -361,8 +361,8 @@ def get_throttle_params_by_room(self, pusher_id): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so _simple_upsert will retry - yield self._simple_upsert( + # (pusher, room_id) so simple_upsert will retry + yield self.db.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 8b17334ff434..96e54d145ebf 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,6 +22,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -61,7 +62,7 @@ def get_users_with_read_receipts_in_room(self, room_id): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self._simple_select_list( + return self.db.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,7 +71,7 @@ def get_receipts_for_room(self, room_id, receipt_type): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -84,7 +85,7 @@ def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -108,7 +109,7 @@ def f(txn): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) + rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) return { row[0]: { "event_id": row[1], @@ -187,11 +188,11 @@ def f(txn): txn.execute(sql, (room_id, to_key)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return rows - rows = yield self.runInteraction("get_linearized_receipts_for_room", f) + rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -237,9 +238,11 @@ def f(txn): txn.execute(sql + clause, [to_key] + list(args)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) + txn_results = yield self.db.runInteraction( + "_get_linearized_receipts_for_rooms", f + ) results = {} for row in txn_results: @@ -282,7 +285,7 @@ def get_all_updated_receipts_txn(txn): return list(r[0:5] + (json.loads(r[5]),) for r in txn) - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) @@ -313,14 +316,14 @@ def _invalidate_get_users_with_receipts_in_room( class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -335,7 +338,7 @@ def insert_linearized_receipt_txn( otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self._simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -388,7 +391,7 @@ def insert_linearized_receipt_txn( (user_id, room_id, receipt_type), ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_linearized", keyvalues={ @@ -398,7 +401,7 @@ def insert_linearized_receipt_txn( }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_linearized", values={ @@ -453,13 +456,13 @@ def graph_to_linear(txn): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.runInteraction( + linearized_event_id = yield self.db.runInteraction( "insert_receipt_conv", graph_to_linear ) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - event_ts = yield self.runInteraction( + event_ts = yield self.db.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -488,7 +491,7 @@ def graph_to_linear(txn): return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.runInteraction( + return self.db.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -514,7 +517,7 @@ def insert_graph_receipt_txn( self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -523,7 +526,7 @@ def insert_graph_receipt_txn( "user_id": user_id, }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 98cf6427c315..5e8ecac0ea73 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -26,8 +26,8 @@ from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -37,15 +37,15 @@ class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @cached() def get_user_by_id(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -94,7 +94,7 @@ def get_user_by_access_token(self, token): including the keys `name`, `is_guest`, `device_id`, `token_id`, `valid_until_ms`. """ - return self.runInteraction( + return self.db.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -109,7 +109,7 @@ def get_expiration_ts_for_user(self, user_id): otherwise int representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -137,7 +137,7 @@ def set_account_validity_for_user( """ def set_account_validity_for_user_txn(txn): - self._simple_update_txn( + self.db.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -151,7 +151,7 @@ def set_account_validity_for_user_txn(txn): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.runInteraction( + yield self.db.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) @@ -167,7 +167,7 @@ def set_renewal_token_for_user(self, user_id, renewal_token): Raises: StoreError: The provided token is already set for another user. """ - yield self._simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -184,7 +184,7 @@ def get_user_from_renewal_token(self, renewal_token): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -203,7 +203,7 @@ def get_renewal_token_for_user(self, user_id): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -229,9 +229,9 @@ def select_users_txn(txn, now_ms, renew_at): ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - res = yield self.runInteraction( + res = yield self.db.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), @@ -250,7 +250,7 @@ def set_renewal_mail_status(self, user_id, email_sent): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self._simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -265,7 +265,7 @@ def delete_account_validity_for_user(self, user_id): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self._simple_delete_one( + yield self.db.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", @@ -281,7 +281,7 @@ def is_server_admin(self, user): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -299,7 +299,7 @@ def set_server_admin(self, user, admin): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self._simple_update_one( + return self.db.simple_update_one( table="users", keyvalues={"name": user.to_string()}, updatevalues={"admin": 1 if admin else 0}, @@ -316,7 +316,7 @@ def _query_for_auth(self, txn, token): ) txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0] @@ -332,7 +332,9 @@ def is_real_user(self, user_id): Returns: Deferred[bool]: True if user 'user_type' is null or empty string """ - res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + res = yield self.db.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) return res @cachedInlineCallbacks() @@ -345,13 +347,13 @@ def is_support_user(self, user_id): Returns: Deferred[bool]: True if user is of type UserTypes.SUPPORT """ - res = yield self.runInteraction( + res = yield self.db.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) return res def is_real_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -361,7 +363,7 @@ def is_real_user_txn(self, txn, user_id): return res is None def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -380,7 +382,7 @@ def f(txn): txn.execute(sql, (user_id,)) return dict(txn) - return self.runInteraction("get_users_by_id_case_insensitive", f) + return self.db.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -394,7 +396,7 @@ async def get_user_by_external_id( Returns: str|None: the mxid of the user, or None if they are not known """ - return await self._simple_select_one_onecol( + return await self.db.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -408,12 +410,12 @@ def count_all_users(self): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret def count_daily_user_type(self): @@ -445,7 +447,7 @@ def _count_daily_user_type(txn): results[row[0]] = row[1] return results - return self.runInteraction("count_daily_user_type", _count_daily_user_type) + return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) @defer.inlineCallbacks def count_nonbridged_users(self): @@ -459,7 +461,7 @@ def _count_users(txn): (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret @defer.inlineCallbacks @@ -468,12 +470,12 @@ def count_real_users(self): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_real_users", _count_users) + ret = yield self.db.runInteraction("count_real_users", _count_users) return ret @defer.inlineCallbacks @@ -503,7 +505,7 @@ def _find_next_generated_user_id(txn): return ( ( - yield self.runInteraction( + yield self.db.runInteraction( "find_next_generated_user_id", _find_next_generated_user_id ) ) @@ -520,7 +522,7 @@ def get_user_id_by_threepid(self, medium, address): Returns: Deferred[str|None]: user id or None if no user id/threepid mapping exists """ - user_id = yield self.runInteraction( + user_id = yield self.db.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -536,7 +538,7 @@ def get_user_id_by_threepid_txn(self, txn, medium, address): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self._simple_select_one_txn( + ret = self.db.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -549,7 +551,7 @@ def get_user_id_by_threepid_txn(self, txn, medium, address): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert( + yield self.db.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -557,7 +559,7 @@ def user_add_threepid(self, user_id, medium, address, validated_at, added_at): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self._simple_select_list( + ret = yield self.db.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -566,7 +568,7 @@ def user_get_threepids(self, user_id): return ret def user_delete_threepid(self, user_id, medium, address): - return self._simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -579,7 +581,7 @@ def user_delete_threepids(self, user_id: str): user_id: The user id to delete all threepids of """ - return self._simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -601,7 +603,7 @@ def add_user_bound_threepid(self, user_id, medium, address, id_server): """ # We need to use an upsert, in case they user had already bound the # threepid - return self._simple_upsert( + return self.db.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -627,7 +629,7 @@ def user_get_bound_threepids(self, user_id): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self._simple_select_list( + return self.db.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -648,7 +650,7 @@ def remove_user_bound_threepid(self, user_id, medium, address, id_server): Returns: Deferred """ - return self._simple_delete( + return self.db.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -671,7 +673,7 @@ def get_id_servers_user_bound(self, user_id, medium, address): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -689,7 +691,7 @@ def get_user_deactivated_status(self, user_id): defer.Deferred(bool): The requested value. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -756,13 +758,13 @@ def get_threepid_validation_session_txn(txn): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None return rows[0] - return self.runInteraction( + return self.db.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) @@ -776,39 +778,37 @@ def delete_threepid_session(self, session_id): """ def delete_threepid_session_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) -class RegistrationBackgroundUpdateStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore -): - def __init__(self, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): + def __init__(self, database: Database, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config - self.register_background_index_update( + self.db.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", @@ -818,13 +818,13 @@ def __init__(self, db_conn, hs): # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. - self.register_noop_background_update("refresh_tokens_device_index") + self.db.updates.register_noop_background_update("refresh_tokens_device_index") - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) @@ -857,7 +857,7 @@ def _background_update_set_deactivated_flag_txn(txn): (last_user, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return True, 0 @@ -871,7 +871,7 @@ def _background_update_set_deactivated_flag_txn(txn): logger.info("Marked %d rows as deactivated", rows_processed_nb) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) @@ -880,12 +880,12 @@ def _background_update_set_deactivated_flag_txn(txn): else: return False, len(rows) - end, nb_processed = yield self.runInteraction( + end, nb_processed = yield self.db.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: - yield self._end_background_update("users_set_deactivated_flag") + yield self.db.updates._end_background_update("users_set_deactivated_flag") return nb_processed @@ -911,21 +911,29 @@ def _bg_user_threepids_grandfather_txn(txn): txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.runInteraction( + yield self.db.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self._end_background_update("user_threepids_grandfather") + yield self.db.updates._end_background_update("user_threepids_grandfather") return 1 class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens def start_cull(): # run as a background process to make sure that the database transactions @@ -953,7 +961,7 @@ def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): """ next_id = self._access_tokens_id_gen.get_next() - yield self._simple_insert( + yield self.db.simple_insert( "access_tokens", { "id": next_id, @@ -995,7 +1003,7 @@ def register_user( Raises: StoreError if the user_id could not be registered. """ - return self.runInteraction( + return self.db.runInteraction( "register_user", self._register_user, user_id, @@ -1029,7 +1037,7 @@ def _register_user( # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self._simple_select_one_txn( + self.db.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1037,7 +1045,7 @@ def _register_user( allow_none=False, ) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1051,7 +1059,7 @@ def _register_user( }, ) else: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "users", values={ @@ -1106,7 +1114,7 @@ def record_user_external_id( external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._simple_insert( + return self.db.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1124,12 +1132,14 @@ def user_set_password_hash(self, user_id, password_hash): """ def user_set_password_hash_txn(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) + return self.db.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) def user_set_consent_version(self, user_id, consent_version): """Updates the user table to record privacy policy consent @@ -1144,7 +1154,7 @@ def user_set_consent_version(self, user_id, consent_version): """ def f(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1152,7 +1162,7 @@ def f(txn): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_version", f) + return self.db.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): """Updates the user table to record that we have sent the user a server @@ -1168,7 +1178,7 @@ def user_set_consent_server_notice_sent(self, user_id, consent_version): """ def f(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1176,7 +1186,7 @@ def f(txn): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_server_notice_sent", f) + return self.db.runInteraction("user_set_consent_server_notice_sent", f) def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ @@ -1222,11 +1232,11 @@ def f(txn): return tokens_and_devices - return self.runInteraction("user_delete_access_tokens", f) + return self.db.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1234,11 +1244,11 @@ def f(txn): txn, self.get_user_by_access_token, (access_token,) ) - return self.runInteraction("delete_access_token", f) + return self.db.runInteraction("delete_access_token", f) @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1253,7 +1263,7 @@ def add_user_pending_deactivation(self, user_id): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self._simple_insert( + return self.db.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1266,7 +1276,7 @@ def del_user_pending_deactivation(self, user_id): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self._simple_delete( + return self.db.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1277,7 +1287,7 @@ def get_user_pending_deactivation(self): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1307,7 +1317,7 @@ def validate_threepid_session(self, session_id, client_secret, token, current_ts # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1325,7 +1335,7 @@ def validate_threepid_session_txn(txn): 400, "This client_secret does not match the provided session_id" ) - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1350,7 +1360,7 @@ def validate_threepid_session_txn(txn): ) # Looks good. Validate the session - self._simple_update_txn( + self.db.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1360,7 +1370,7 @@ def validate_threepid_session_txn(txn): return next_link # Return next_link if it exists - return self.runInteraction( + return self.db.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) @@ -1393,7 +1403,7 @@ def upsert_threepid_validation_session( if validated_at: insertion_values["validated_at"] = validated_at - return self._simple_upsert( + return self.db.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1431,7 +1441,7 @@ def start_or_continue_validation_session( def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1444,7 +1454,7 @@ def start_or_continue_validation_session_txn(txn): ) # Create a new validation token with this session ID - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1455,7 +1465,7 @@ def start_or_continue_validation_session_txn(txn): }, ) - return self.runInteraction( + return self.db.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) @@ -1470,7 +1480,7 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts): """ return txn.execute(sql, (ts,)) - return self.runInteraction( + return self.db.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), @@ -1485,7 +1495,7 @@ def set_user_deactivated_status(self, user_id, deactivated): deactivated (bool): The value to set for `deactivated`. """ - yield self.runInteraction( + yield self.db.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1493,7 +1503,7 @@ def set_user_deactivated_status(self, user_id, deactivated): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1502,3 +1512,59 @@ def set_user_deactivated_status_txn(self, txn, user_id, deactivated): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + yield self.db.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index 7d5de0ea2ea4..1c07c7a425fc 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -22,7 +22,7 @@ class RejectionsStore(SQLBaseStore): def _store_rejections_txn(self, txn, event_id, reason): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="rejections", values={ @@ -33,7 +33,7 @@ def _store_rejections_txn(self, txn, event_id, reason): ) def get_rejection_reason(self, event_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index 858f65582bce..046c2b4845bb 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -129,7 +129,7 @@ def _get_recent_references_for_event_txn(txn): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @@ -223,7 +223,7 @@ def _get_aggregation_groups_for_event_txn(txn): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -268,7 +268,7 @@ def _get_applicable_edit_txn(txn): if row: return row[0] - edit_id = yield self.runInteraction( + edit_id = yield self.db.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) @@ -318,7 +318,7 @@ def _get_if_user_has_annotated_event(txn): return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) @@ -352,7 +352,7 @@ def _handle_event_relations(self, txn, event): aggregation_key = relation.get("key") - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="event_relations", values={ @@ -380,6 +380,6 @@ def _handle_redaction(self, txn, redacted_event_id): redacted_event_id (str): The event that was redacted. """ - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index b7f9024811df..0148be20d36a 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -29,6 +29,7 @@ from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore +from synapse.storage.database import Database from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -53,7 +54,7 @@ def get_room(self, room_id): Returns: A dict containing the room information, or None if the room is unknown. """ - return self._simple_select_one( + return self.db.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -62,7 +63,7 @@ def get_room(self, room_id): ) def get_public_room_ids(self): - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -119,7 +120,7 @@ def _count_public_rooms_txn(txn): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.runInteraction("count_public_rooms", _count_public_rooms_txn) + return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) @defer.inlineCallbacks def get_largest_public_rooms( @@ -252,21 +253,21 @@ def get_largest_public_rooms( def _get_largest_public_rooms_txn(txn): txn.execute(sql, query_args) - results = self.cursor_to_dict(txn) + results = self.db.cursor_to_dict(txn) if not forwards: results.reverse() return results - ret_val = yield self.runInteraction( + ret_val = yield self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) defer.returnValue(ret_val) @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -287,7 +288,7 @@ def get_ratelimit_for_user(self, user_id): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self._simple_select_one( + row = yield self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -329,9 +330,9 @@ def get_retention_policy_for_room_txn(txn): (room_id,), ) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) @@ -360,13 +361,13 @@ def get_retention_policy_for_room_txn(txn): defer.returnValue(row) -class RoomStore(RoomWorkerStore, SearchStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) +class RoomBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.config = hs.config - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, ) @@ -395,7 +396,7 @@ def _background_insert_retention_txn(txn): (last_room, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return True @@ -407,7 +408,7 @@ def _background_insert_retention_txn(txn): ev = json.loads(row["json"]) retention_policy = json.dumps(ev["content"]) - self._simple_insert_txn( + self.db.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -420,7 +421,7 @@ def _background_insert_retention_txn(txn): logger.info("Inserted %d rows into room_retention", len(rows)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) @@ -429,15 +430,22 @@ def _background_insert_retention_txn(txn): else: return False - end = yield self.runInteraction( + end = yield self.db.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) if end: - yield self._end_background_update("insert_room_retention") + yield self.db.updates._end_background_update("insert_room_retention") defer.returnValue(batch_size) + +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): """Stores a room. @@ -453,7 +461,7 @@ def store_room(self, room_id, room_creator_user_id, is_public): try: def store_room_txn(txn, next_id): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "rooms", { @@ -463,7 +471,7 @@ def store_room_txn(txn, next_id): }, ) if is_public: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -474,7 +482,7 @@ def store_room_txn(txn, next_id): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction("store_room_txn", store_room_txn, next_id) + yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -482,14 +490,14 @@ def store_room_txn(txn, next_id): @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self._simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -507,7 +515,7 @@ def set_room_is_public_txn(txn, next_id): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -520,7 +528,7 @@ def set_room_is_public_txn(txn, next_id): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() @@ -547,7 +555,7 @@ def set_room_is_public_appservice( def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -560,7 +568,7 @@ def set_room_is_public_appservice_txn(txn, next_id): # We've already inserted, nothing to do. return else: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -570,7 +578,7 @@ def set_room_is_public_appservice_txn(txn, next_id): }, ) - entries = self._simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -588,7 +596,7 @@ def set_room_is_public_appservice_txn(txn, next_id): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -601,7 +609,7 @@ def set_room_is_public_appservice_txn(txn, next_id): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -618,7 +626,7 @@ def f(txn): row = txn.fetchone() return row[0] or 0 - return self.runInteraction("get_rooms", f) + return self.db.runInteraction("get_rooms", f) def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: @@ -652,7 +660,7 @@ def _store_retention_policy_for_room_txn(self, txn, event): # Ignore the event if one of the value isn't an integer. return - self._simple_insert_txn( + self.db.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -671,7 +679,7 @@ def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self._simple_insert( + return self.db.simple_insert( table="event_reports", values={ "id": next_id, @@ -704,7 +712,9 @@ def get_all_new_public_rooms(txn): if prev_id == current_id: return defer.succeed([]) - return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) @defer.inlineCallbacks def block_room(self, room_id, user_id): @@ -717,14 +727,14 @@ def block_room(self, room_id, user_id): Returns: Deferred """ - yield self._simple_upsert( + yield self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.runInteraction( + yield self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, @@ -755,7 +765,9 @@ def _get_media_mxcs_in_room_txn(txn): return local_media_mxcs, remote_media_mxcs - return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + return self.db.runInteraction( + "get_media_ids_in_room", _get_media_mxcs_in_room_txn + ) def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines @@ -794,7 +806,7 @@ def _quarantine_media_in_room_txn(txn): return total_media_quarantined - return self.runInteraction( + return self.db.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -899,7 +911,7 @@ def get_rooms_for_retention_period_in_range_txn(txn): txn.execute(sql, args) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) rooms_dict = {} for row in rows: @@ -915,7 +927,7 @@ def get_rooms_for_retention_period_in_range_txn(txn): txn.execute(sql) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) # If a room isn't already in the dict (i.e. it doesn't have a retention # policy in its state), add it with a null policy. @@ -928,7 +940,7 @@ def get_rooms_for_retention_period_in_range_txn(txn): return rooms_dict - rooms = yield self.runInteraction( + rooms = yield self.db.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 2af24a20b705..92e3b9c512f7 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import Iterable, List from six import iteritems, itervalues @@ -25,9 +26,13 @@ from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import ( + LoggingTransaction, + SQLBaseStore, + make_in_list_sql_clause, +) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -50,8 +55,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -115,7 +120,7 @@ def _transact(txn): txn.execute(query) return list(txn)[0][0] - count = yield self.runInteraction("get_known_servers", _transact) + count = yield self.db.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -127,7 +132,7 @@ def _check_safe_current_state_events_membership_updated_txn(self, txn): membership column is up to date """ - pending_update = self._simple_select_one_txn( + pending_update = self.db.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -143,7 +148,7 @@ def _check_safe_current_state_events_membership_updated_txn(self, txn): 15.0, run_as_background_process, "_check_safe_current_state_events_membership_updated", - self.runInteraction, + self.db.runInteraction, "_check_safe_current_state_events_membership_updated", self._check_safe_current_state_events_membership_updated_txn, ) @@ -160,7 +165,7 @@ def get_hosts_in_room(self, room_id, cache_context): @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -268,7 +273,7 @@ def _get_room_summary_txn(txn): return res - return self.runInteraction("get_room_summary", _get_room_summary_txn) + return self.db.runInteraction("get_room_summary", _get_room_summary_txn) def _get_user_counts_in_room_txn(self, txn, room_id): """ @@ -338,7 +343,7 @@ def get_rooms_for_user_where_membership_is(self, user_id, membership_list): if not membership_list: return defer.succeed(None) - rooms = yield self.runInteraction( + rooms = yield self.db.runInteraction( "get_rooms_for_user_where_membership_is", self._get_rooms_for_user_where_membership_is_txn, user_id, @@ -391,7 +396,7 @@ def _get_rooms_for_user_where_membership_is_txn( ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] if do_invite: sql = ( @@ -411,7 +416,7 @@ def _get_rooms_for_user_where_membership_is_txn( stream_ordering=r["stream_ordering"], membership=Membership.INVITE, ) - for r in self.cursor_to_dict(txn) + for r in self.db.cursor_to_dict(txn) ) return results @@ -602,7 +607,7 @@ def _get_joined_profiles_from_event_ids(self, event_ids): to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -642,7 +647,7 @@ def is_host_joined(self, room_id, host): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -682,7 +687,7 @@ def was_host_joined(self, room_id, host): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -752,7 +757,7 @@ def f(txn): rows = txn.fetchall() return rows[0][0] - count = yield self.runInteraction("did_forget_membership", f) + count = yield self.db.runInteraction("did_forget_membership", f) return count == 0 @cached() @@ -789,7 +794,7 @@ def _get_forgotten_rooms_for_user_txn(txn): txn.execute(sql, (user_id,)) return set(row[0] for row in txn if row[1] == 0) - return self.runInteraction( + return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -804,7 +809,7 @@ def get_rooms_user_has_been_in(self, user_id): Deferred[set[str]]: Set of room IDs. """ - room_ids = yield self._simple_select_onecol( + room_ids = yield self.db.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -813,18 +818,34 @@ def get_rooms_user_has_been_in(self, user_id): return set(room_ids) + def get_membership_from_event_ids( + self, member_event_ids: Iterable[str] + ) -> List[dict]: + """Get user_id and membership of a set of event IDs. + """ + + return self.db.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=("user_id", "membership", "event_id"), + keyvalues={}, + batch_size=500, + desc="get_membership_from_event_ids", + ) + -class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( +class RoomMemberBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "room_membership_forgotten_idx", index_name="room_memberships_user_room_forgotten", table="room_memberships", @@ -857,7 +878,7 @@ def add_membership_profile_txn(txn): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -892,18 +913,20 @@ def add_membership_profile_txn(txn): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress ) return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + yield self.db.updates._end_background_update( + _MEMBERSHIP_PROFILE_UPDATE_NAME + ) return result @@ -942,7 +965,7 @@ def _background_current_state_membership_txn(txn, last_processed_room): last_processed_room = next_room - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, {"last_processed_room": last_processed_room}, @@ -954,26 +977,28 @@ def _background_current_state_membership_txn(txn, last_processed_room): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.runInteraction( + row_count, finished = yield self.db.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, ) if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + yield self.db.updates._end_background_update( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME + ) return row_count class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1011,7 +1036,7 @@ def _store_room_members_txn(self, txn, events, backfilled): is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_invites", values={ @@ -1051,7 +1076,7 @@ def f(txn, stream_ordering): txn.execute(sql, (stream_ordering, True, room_id, user_id)) with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) + yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" @@ -1074,7 +1099,7 @@ def f(txn): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.runInteraction("forget_membership", f) + return self.db.runInteraction("forget_membership", f) class _JoinedHostsCache(object): diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 000000000000..81a36a8b1dc9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql index fe51b02309b7..ea95db0ed717 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql @@ -14,4 +14,3 @@ */ ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; -CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored; diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql index 77a5eca49943..49ce35d79472 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql @@ -14,7 +14,9 @@ */ ALTER TABLE redactions ADD COLUMN received_ts BIGINT; -CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored; INSERT INTO background_updates (update_name, progress_json) VALUES ('redactions_received_ts', '{}'); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('redactions_have_censored_ts_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql new file mode 100644 index 000000000000..b7550f6f4e70 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS redactions_have_censored; diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql index 27a96123e377..5c5fffcafb16 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql @@ -40,7 +40,8 @@ CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( signature TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); +-- replaced by the index created in signing_keys_nonunique_signatures.sql +-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); -- stream of user signature updates CREATE TABLE IF NOT EXISTS user_signature_stream ( diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql new file mode 100644 index 000000000000..0aa90ebf0c68 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* The cross-signing signatures index should not be a unique index, because a + * user may upload multiple signatures for the same target user. The previous + * index was unique, so delete it if it's there and create a new non-unique + * index. */ + +DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT +EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index a28a194cfe9b..dfb46ee0f8ad 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -24,9 +24,9 @@ from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -37,23 +37,23 @@ ) -class SearchBackgroundUpdateStore(BackgroundUpdateStore): +class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -62,9 +62,11 @@ def __init__(self, db_conn, hs): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + self.db.updates.register_noop_background_update( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME + ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) @@ -94,7 +96,7 @@ def reindex_search_txn(txn): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -154,18 +156,18 @@ def reindex_search_txn(txn): "rows_inserted": rows_inserted + len(event_search_rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) return len(event_search_rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) if not result: - yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) return result @@ -207,9 +209,11 @@ def create_index(conn): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME + ) return 1 @defer.inlineCallbacks @@ -238,14 +242,14 @@ def create_index(conn): ) conn.set_session(autocommit=False) - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.runInteraction( + yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -275,18 +279,20 @@ def reindex_search_txn(txn): "have_added_indexes": True, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) return len(rows), True - num_rows, finished = yield self.runInteraction( + num_rows, finished = yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) if not finished: - yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_ORDER_UPDATE_NAME + ) return num_rows @@ -338,8 +344,8 @@ def store_search_entries_txn(self, txn, entries): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) def store_event_search_txn(self, txn, event, key, value): """Add event to the search table @@ -442,7 +448,9 @@ def search_msgs(self, room_ids, search_term, keys): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_msgs", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -461,8 +469,8 @@ def search_msgs(self, room_ids, search_term, keys): count_sql += " GROUP BY room_id" - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -592,7 +600,9 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None args.append(limit) - results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_rooms", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -606,8 +616,8 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None count_sql += " GROUP BY room_id" - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -692,7 +702,7 @@ def f(txn): return highlight_words - return self.runInteraction("_find_highlights", f) + return self.db.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index 556191b76f48..563216b63c92 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -48,7 +48,7 @@ def f(txn): for event_id in event_ids } - return self.runInteraction("get_event_reference_hashes", f) + return self.db.runInteraction("get_event_reference_hashes", f) @defer.inlineCallbacks def add_event_hashes(self, event_ids): @@ -98,4 +98,4 @@ def _store_event_reference_hashes_txn(self, txn, events): } ) - self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 6a90daea31bc..9ef7b48c740d 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -27,8 +27,8 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.util.caches import get_cache_factor_for, intern_string @@ -89,7 +89,7 @@ def _count_state_group_hops_txn(self, txn, state_group): count = 0 while next_group: - next_group = self._simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -192,7 +192,7 @@ def _get_state_groups_from_groups_txn( ): break - next_group = self._simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -214,8 +214,8 @@ class StateGroupWorkerStore( STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -348,7 +348,9 @@ def _get_current_state_ids_txn(txn): (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn } - return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) + return self.db.runInteraction( + "get_current_state_ids", _get_current_state_ids_txn + ) # FIXME: how should this be cached? def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): @@ -392,7 +394,7 @@ def _get_filtered_current_state_ids_txn(txn): return results - return self.runInteraction( + return self.db.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @@ -431,7 +433,7 @@ def get_state_group_delta(self, state_group): """ def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( + prev_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -442,7 +444,7 @@ def _get_state_group_delta_txn(txn): if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self._simple_select_list_txn( + delta_ids = self.db.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -454,7 +456,9 @@ def _get_state_group_delta_txn(txn): {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -540,7 +544,7 @@ def _get_state_groups_from_groups(self, groups, state_filter): chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -644,7 +648,7 @@ def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -661,7 +665,7 @@ def _get_state_group_for_event(self, event_id): def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -902,7 +906,7 @@ def _store_state_group_txn(txn): state_group = self.database_engine.get_next_state_group_id(txn) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -911,7 +915,7 @@ def _store_state_group_txn(txn): # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self._simple_select_one_onecol_txn( + is_in_db = self.db.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -926,13 +930,13 @@ def _store_state_group_txn(txn): potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -947,7 +951,7 @@ def _store_state_group_txn(txn): ], ) else: - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -993,7 +997,7 @@ def _store_state_group_txn(txn): return state_group - return self.runInteraction("store_state_group", _store_state_group_txn) + return self.db.runInteraction("store_state_group", _store_state_group_txn) @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): @@ -1007,7 +1011,7 @@ def get_referenced_state_groups(self, state_groups): referenced. """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -1019,32 +1023,30 @@ def get_referenced_state_groups(self, state_groups): return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore( - StateGroupBackgroundUpdateStore, BackgroundUpdateStore -): +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", table="current_state_events", columns=["state_key"], where_clause="type='m.room.member'", ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, index_name="event_to_state_groups_sg_index", table="event_to_state_groups", @@ -1065,7 +1067,7 @@ def _background_deduplicate_state(self, progress, batch_size): batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self._execute( + rows = yield self.db.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -1135,13 +1137,13 @@ def reindex_txn(txn): if prev_state.get(key, None) != value } - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={ @@ -1150,13 +1152,13 @@ def reindex_txn(txn): }, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1177,18 +1179,18 @@ def reindex_txn(txn): "max_group": max_group, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) return False, batch_size - finished, result = yield self.runInteraction( + finished, result = yield self.db.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) if finished: - yield self._end_background_update( + yield self.db.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) @@ -1218,9 +1220,9 @@ def reindex_txn(conn): ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) return 1 @@ -1244,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) def _store_event_state_mappings_txn( self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] @@ -1263,7 +1265,7 @@ def _store_event_state_mappings_txn( state_groups[event.event_id] = context.state_group - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 28f33ec18fae..12c982cb2688 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -98,14 +98,14 @@ def get_current_state_deltas_txn(txn): ORDER BY stream_id ASC """ txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.cursor_to_dict(txn) + return clipped_stream_id, self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self._simple_select_one_onecol_txn( + return self.db.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, @@ -113,7 +113,7 @@ def _get_max_stream_id_in_current_state_deltas_txn(self, txn): ) def get_max_stream_id_in_current_state_deltas(self): - return self.runInteraction( + return self.db.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 45b3de7d56c4..7bc186e9a1c8 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -58,8 +59,8 @@ class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() @@ -68,17 +69,17 @@ def __init__(self, db_conn, hs): self.stats_delta_processing_lock = DeferredLock() - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.register_noop_background_update("populate_stats_cleanup") - self.register_noop_background_update("populate_stats_prepare") + self.db.updates.register_noop_background_update("populate_stats_cleanup") + self.db.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -102,7 +103,7 @@ def _populate_stats_process_users(self, progress, batch_size): This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -117,22 +118,22 @@ def _get_next_batch(txn): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: yield self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.runInteraction( + yield self.db.runInteraction( "populate_stats_process_users", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) @@ -145,7 +146,7 @@ def _populate_stats_process_rooms(self, progress, batch_size): This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -160,22 +161,22 @@ def _get_next_batch(txn): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: yield self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.runInteraction( + yield self.db.runInteraction( "_populate_stats_process_rooms", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", progress, ) @@ -186,7 +187,7 @@ def get_stats_positions(self): """ Returns the stats processor positions. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +216,7 @@ def update_room_state(self, room_id, fields): if field and "\0" in field: fields[col] = None - return self._simple_upsert( + return self.db.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -236,7 +237,7 @@ def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): Deferred[list[dict]], where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.runInteraction( + return self.db.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -257,14 +258,14 @@ def _get_statistics_for_subject_txn( ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self._simple_select_list_paginate_txn( + slice_list = self.db.simple_select_list_paginate_txn( txn, table + "_historical", - {id_col: stats_id}, "end_ts", start, size, retcols=selected_columns + ["bucket_size", "end_ts"], + keyvalues={id_col: stats_id}, order_direction="DESC", ) @@ -282,7 +283,7 @@ def get_room_stats_state(self, room_id): "name", "topic", "canonical_alias", "avatar", "join_rules", "history_visibility" """ - return self._simple_select_one( + return self.db.simple_select_one( "room_stats_state", {"room_id": room_id}, retcols=( @@ -308,7 +309,7 @@ def get_earliest_token_for_stats(self, stats_type, id): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -344,14 +345,14 @@ def _bulk_update_stats_delta_txn(txn): complete_with_stream_id=stream_id, ) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": stream_id}, ) - return self.runInteraction( + return self.db.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) @@ -382,7 +383,7 @@ def update_stats_delta( Does not work with per-slice fields. """ - return self.runInteraction( + return self.db.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -517,17 +518,17 @@ def _upsert_with_additive_relatives_txn( else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self._simple_select_one_txn( + current_row = self.db.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self._simple_insert_txn(txn, table, merged_dict) + self.db.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self._simple_update_one_txn(txn, table, keyvalues, current_row) + self.db.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -614,11 +615,11 @@ def _upsert_copy_from_table_with_additive_relatives_txn( txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self._simple_select_one_txn( + src_row = self.db.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self._simple_select_one_txn( + dest_current_row = self.db.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -634,11 +635,11 @@ def _upsert_copy_from_table_with_additive_relatives_txn( **src_row, **additive_relatives, } - self._simple_insert_txn(txn, into_table, merged_dict) + self.db.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -652,7 +653,7 @@ def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): changes. """ - return self.runInteraction( + return self.db.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, @@ -735,7 +736,7 @@ def _calculate_and_set_initial_state_for_room(self, room_id): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="current_state_events", column="type", @@ -791,7 +792,7 @@ def _fetch_current_state_stats(txn): current_state_events_count, users_in_room, pos, - ) = yield self.runInteraction( + ) = yield self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) @@ -866,7 +867,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.runInteraction( + joined_rooms, pos = yield self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 9ae4a913a181..140da8dad686 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,6 +47,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -248,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(StreamWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StreamWorkerStore, self).__init__(database, db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self._get_cache_dict( + event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -397,7 +401,7 @@ def f(txn): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.runInteraction("get_room_events_stream_for_room", f) + rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -447,7 +451,7 @@ def f(txn): return rows - rows = yield self.runInteraction("get_membership_changes_for_user", f) + rows = yield self.db.runInteraction("get_membership_changes_for_user", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -508,7 +512,7 @@ def get_recent_event_ids_for_room(self, room_id, limit, end_token): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -545,7 +549,7 @@ def _f(txn): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.runInteraction("get_room_event_after_stream_ordering", _f) + return self.db.runInteraction("get_room_event_after_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): @@ -559,7 +563,7 @@ def get_room_events_max_id(self, room_id=None): if room_id is None: return "s%d" % (token,) else: - topo = yield self.runInteraction( + topo = yield self.db.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) @@ -573,7 +577,7 @@ def get_stream_token_for_event(self, event_id): Returns: A deferred "s%d" stream token. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -586,7 +590,7 @@ def get_topological_token_for_event(self, event_id): Returns: A deferred "t%d-%d" topological token. """ - return self._simple_select_one( + return self.db.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -610,7 +614,7 @@ def get_max_topological_token(self, room_id, stream_key): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self._execute( + return self.db.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) @@ -664,7 +668,7 @@ def get_events_around( dict """ - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -706,7 +710,7 @@ def _get_events_around_txn( dict """ - results = self._simple_select_one_txn( + results = self.db.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -785,7 +789,7 @@ def get_all_new_events_stream_txn(txn): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) @@ -794,7 +798,7 @@ def get_all_new_events_stream_txn(txn): return upper_bound, events def get_federation_out_pos(self, typ): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -802,7 +806,7 @@ def get_federation_out_pos(self, typ): ) def update_federation_out_pos(self, typ, stream_id): - return self._simple_update_one( + return self.db.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, @@ -953,7 +957,7 @@ def paginate_room_events( if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index aa2433971712..2aa1bafd48d3 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -41,7 +41,7 @@ def get_tags_for_user(self, user_id): tag strings to tag content. """ - deferred = self._simple_select_list( + deferred = self.db.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -78,7 +78,7 @@ def get_all_updated_tags_txn(txn): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = yield self.runInteraction( + tag_ids = yield self.db.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) @@ -98,7 +98,7 @@ def get_tag_content(txn, tag_ids): batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = yield self.runInteraction( + tags = yield self.db.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], @@ -135,7 +135,9 @@ def get_updated_tags_txn(txn): if not changed: return {} - room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) + room_ids = yield self.db.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) results = {} if room_ids: @@ -153,7 +155,7 @@ def get_tags_for_room(self, user_id, room_id): Returns: A deferred list of string tags. """ - return self._simple_select_list( + return self.db.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -178,7 +180,7 @@ def add_tag_to_room(self, user_id, room_id, tag, content): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, @@ -187,7 +189,7 @@ def add_tag_txn(txn, next_id): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("add_tag", add_tag_txn, next_id) + yield self.db.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -210,7 +212,7 @@ def remove_tag_txn(txn, next_id): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 01b1be5e148a..5b07c2fbc0c6 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -24,6 +24,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache # py2 sqlite has buffer hardcoded as only binary type, so we must use it, @@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) @@ -77,7 +78,7 @@ def get_received_txn_response(self, transaction_id, origin): this transaction or a 2-tuple of (int, dict) """ - return self.runInteraction( + return self.db.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -85,7 +86,7 @@ def get_received_txn_response(self, transaction_id, origin): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -119,7 +120,7 @@ def set_received_txn_response(self, transaction_id, origin, code, response_dict) response_json (str) """ - return self._simple_insert( + return self.db.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -148,7 +149,7 @@ def get_destination_retry_timings(self, destination): if result is not SENTINEL: return result - result = yield self.runInteraction( + result = yield self.db.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, @@ -160,7 +161,7 @@ def get_destination_retry_timings(self, destination): return result def _get_destination_retry_timings(self, txn, destination): - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -187,7 +188,7 @@ def set_destination_retry_timings( """ self._destination_retry_cache.pop(destination, None) - return self.runInteraction( + return self.db.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -227,7 +228,7 @@ def _set_destination_retry_timings( # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self._simple_select_one_txn( + prev_row = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -236,7 +237,7 @@ def _set_destination_retry_timings( ) if not prev_row: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="destinations", values={ @@ -247,7 +248,7 @@ def _set_destination_retry_timings( }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, @@ -270,4 +271,6 @@ def _cleanup_transactions(self): def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) + return self.db.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 652abe0e6a03..90c180ec6d86 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -19,9 +19,9 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -32,30 +32,30 @@ TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): +class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) @@ -85,7 +85,7 @@ def _make_staging_area(txn): """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,15 +100,17 @@ def _make_staging_area(txn): txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - yield self._end_background_update("populate_user_directory_createtables") + yield self.db.updates._end_background_update( + "populate_user_directory_createtables" + ) return 1 @defer.inlineCallbacks @@ -116,7 +118,7 @@ def _populate_user_directory_cleanup(self, progress, batch_size): """ Update the user directory stream position, then clean up the old tables. """ - position = yield self._simple_select_one_onecol( + position = yield self.db.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -126,11 +128,11 @@ def _delete_staging_area(txn): txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) - yield self._end_background_update("populate_user_directory_cleanup") + yield self.db.updates._end_background_update("populate_user_directory_cleanup") return 1 @defer.inlineCallbacks @@ -170,13 +172,15 @@ def _get_next_batch(txn): return rooms_to_work_on - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_user_directory_process_rooms") + yield self.db.updates._end_background_update( + "populate_user_directory_process_rooms" + ) return 1 logger.info( @@ -243,12 +247,12 @@ def _get_next_batch(txn): to_insert.clear() # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -267,7 +271,9 @@ def _populate_user_directory_process_users(self, progress, batch_size): If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 def _get_next_batch(txn): @@ -291,13 +297,15 @@ def _get_next_batch(txn): return users_to_work_on - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more users -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 logger.info( @@ -312,12 +320,12 @@ def _get_next_batch(txn): ) # We've finished processing a user. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) @@ -361,7 +369,7 @@ def update_profile_in_user_dir(self, user_id, display_name, avatar_url): """ def _update_profile_in_user_dir_txn(txn): - new_entry = self._simple_upsert_txn( + new_entry = self.db.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +443,7 @@ def _update_profile_in_user_dir_txn(txn): ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -448,7 +456,7 @@ def _update_profile_in_user_dir_txn(txn): txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction( + return self.db.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) @@ -462,7 +470,7 @@ def add_users_who_share_private_room(self, room_id, user_id_tuples): """ def _add_users_who_share_room_txn(txn): - self._simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -474,7 +482,7 @@ def _add_users_who_share_room_txn(txn): value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) @@ -489,7 +497,7 @@ def add_users_in_public_rooms(self, room_id, user_ids): def _add_users_in_public_rooms_txn(txn): - self._simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -498,7 +506,7 @@ def _add_users_in_public_rooms_txn(txn): value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) @@ -513,13 +521,13 @@ def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.runInteraction( + return self.db.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @cached() def get_user_in_directory(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -528,7 +536,7 @@ def get_user_in_directory(self, user_id): ) def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( + return self.db.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -542,47 +550,47 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) @defer.inlineCallbacks def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self._simple_select_onecol( + user_ids_share_pub = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self._simple_select_onecol( + user_ids_share_priv = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -605,23 +613,23 @@ def remove_user_who_share_room(self, user_id, room_id): """ def _remove_user_who_share_room_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) @@ -636,14 +644,14 @@ def get_user_dir_rooms_user_is_in(self, user_id): Returns: list: user_id """ - rows = yield self._simple_select_onecol( + rows = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self._simple_select_onecol( + pub_rows = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,14 +682,14 @@ def get_rooms_in_common_for_users(self, user_id, other_user_id): ) f2 USING (room_id) """ - rows = yield self._execute( + rows = yield self.db.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] def get_user_directory_stream_pos(self): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", @@ -786,8 +794,8 @@ def search_user_dir(self, user_id, search_term, limit): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self._execute( - "search_user_dir", self.cursor_to_dict, sql, *args + results = yield self.db.execute( + "search_user_dir", self.db.cursor_to_dict, sql, *args ) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index aa4f0da5f01b..af8025bc17a0 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ def is_user_erased(self, user_id): Returns: Deferred[bool]: True if the user has requested erasure """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ def are_users_erased(self, user_ids): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -88,4 +88,4 @@ def f(txn): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.runInteraction("mark_user_erased", f) + return self.db.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/database.py b/synapse/storage/database.py new file mode 100644 index 000000000000..ec19ae1d9db0 --- /dev/null +++ b/synapse/storage/database.py @@ -0,0 +1,1490 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import sys +import time +from typing import Iterable, Tuple + +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range + +from prometheus_client import Histogram + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdater +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.util.stringutils import exception_to_unicode + +# import a function which will return a monotonic time, in seconds +try: + # on python 3, use time.monotonic, since time.clock can go backwards + from time import monotonic as monotonic_time +except ImportError: + # ... but python 2 doesn't have it + from time import clock as monotonic_time + +logger = logging.getLogger(__name__) + +try: + MAX_TXN_ID = sys.maxint - 1 +except AttributeError: + # python 3 does not have a maximum int value + MAX_TXN_ID = 2 ** 63 - 1 + +sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") +perf_logger = logging.getLogger("synapse.storage.TIME") + +sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") + +sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) +sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) + + +# Unique indexes which have been added in background updates. Maps from table name +# to the name of the background update which added the unique index to that table. +# +# This is used by the upsert logic to figure out which tables are safe to do a proper +# UPSERT on: until the relevant background update has completed, we +# have to emulate an upsert by locking the table. +# +UNIQUE_INDEX_BACKGROUND_UPDATES = { + "user_ips": "user_ips_device_unique_index", + "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", + "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", + "event_search": "event_search_event_id_idx", +} + + +class LoggingTransaction(object): + """An object that almost-transparently proxies for the 'txn' object + passed to the constructor. Adds logging and metrics to the .execute() + method. + + Args: + txn: The database transcation object to wrap. + name (str): The name of this transactions for logging. + database_engine (Sqlite3Engine|PostgresEngine) + after_callbacks(list|None): A list that callbacks will be appended to + that have been added by `call_after` which should be run on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. + exception_callbacks(list|None): A list that callbacks will be appended + to that have been added by `call_on_exception` which should be run + if transaction ends with an error. None indicates that no callbacks + should be allowed to be scheduled to run. + """ + + __slots__ = [ + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", + ] + + def __init__( + self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + ): + object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) + object.__setattr__(self, "database_engine", database_engine) + object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) + + def call_after(self, callback, *args, **kwargs): + """Call the given callback on the main twisted thread after the + transaction has finished. Used to invalidate the caches on the + correct thread. + """ + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) + + def __getattr__(self, name): + return getattr(self.txn, name) + + def __setattr__(self, name, value): + setattr(self.txn, name, value) + + def __iter__(self): + return self.txn.__iter__() + + def execute_batch(self, sql, args): + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import execute_batch + + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + else: + for val in args: + self.execute(sql, val) + + def execute(self, sql, *args): + self._do_execute(self.txn.execute, sql, *args) + + def executemany(self, sql, *args): + self._do_execute(self.txn.executemany, sql, *args) + + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + + # TODO(paul): Maybe use 'info' and 'debug' for values? + sql_logger.debug("[SQL] {%s} %s", self.name, sql) + + sql = self.database_engine.convert_param_style(sql) + if args: + try: + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + except Exception: + # Don't let logging failures stop SQL from working + pass + + start = time.time() + + try: + return func(sql, *args) + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise + finally: + secs = time.time() - start + sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) + sql_query_timer.labels(sql.split()[0]).observe(secs) + + +class PerformanceCounters(object): + def __init__(self): + self.current_counters = {} + self.previous_counters = {} + + def update(self, key, duration_secs): + count, cum_time = self.current_counters.get(key, (0, 0)) + count += 1 + cum_time += duration_secs + self.current_counters[key] = (count, cum_time) + + def interval(self, interval_duration_secs, limit=3): + counters = [] + for name, (count, cum_time) in iteritems(self.current_counters): + prev_count, prev_time = self.previous_counters.get(name, (0, 0)) + counters.append( + ( + (cum_time - prev_time) / interval_duration_secs, + count - prev_count, + name, + ) + ) + + self.previous_counters = dict(self.current_counters) + + counters.sort(reverse=True) + + top_n_counters = ", ".join( + "%s(%d): %.3f%%" % (name, count, 100 * ratio) + for ratio, count, name in counters[:limit] + ) + + return top_n_counters + + +class Database(object): + """Wraps a single physical database and connection pool. + + A single database may be used by multiple data stores. + """ + + _TXN_ID = 0 + + def __init__(self, hs): + self.hs = hs + self._clock = hs.get_clock() + self._db_pool = hs.get_db_pool() + + self.updates = BackgroundUpdater(hs, self) + + self._previous_txn_total_time = 0 + self._current_txn_total_time = 0 + self._previous_loop_ts = 0 + + # TODO(paul): These can eventually be removed once the metrics code + # is running in mainline, and we have some nice monitoring frontends + # to watch it + self._txn_perf_counters = PerformanceCounters() + + self.engine = hs.database_engine + + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + + # We add the user_directory_search table to the blacklist on SQLite + # because the existing search table does not have an index, making it + # unsafe to use native upserts. + if isinstance(self.engine, Sqlite3Engine): + self._unsafe_to_upsert_tables.add("user_directory_search") + + if self.engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later( + 0.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait 15 sec and check again. + """ + updates = yield self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): + if update_name not in updates: + logger.debug("Now safe to upsert in %s", table) + self._unsafe_to_upsert_tables.discard(table) + + # If there's any updates still running, reschedule to run. + if updates: + self._clock.call_later( + 15.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + def start_profiling(self): + self._previous_loop_ts = monotonic_time() + + def loop(): + curr = self._current_txn_total_time + prev = self._previous_txn_total_time + self._previous_txn_total_time = curr + + time_now = monotonic_time() + time_then = self._previous_loop_ts + self._previous_loop_ts = time_now + + duration = time_now - time_then + ratio = (curr - prev) / duration + + top_three_counters = self._txn_perf_counters.interval(duration, limit=3) + + perf_logger.info( + "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters + ) + + self._clock.looping_call(loop, 10000) + + def new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): + start = monotonic_time() + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) + + name = "%s-%x" % (desc, txn_id) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.engine, + after_callbacks, + exception_callbacks, + ) + try: + r = func(cursor, *args, **kwargs) + conn.commit() + return r + except self.engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warning( + "[TXN OPERROR] {%s} %s %d/%d", + name, + exception_to_unicode(e), + i, + N, + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) + ) + continue + raise + except self.engine.module.DatabaseError as e: + if self.engine.is_deadlock(e): + logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", + name, + exception_to_unicode(e1), + ) + continue + raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. + # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = monotonic_time() + duration = end - start + + LoggingContext.current_context().add_database_transaction(duration) + + transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, duration) + sql_txn_timer.labels(desc).observe(duration) + + @defer.inlineCallbacks + def runInteraction(self, desc, func, *args, **kwargs): + """Starts a transaction on the database and runs a given function + + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + after_callbacks = [] + exception_callbacks = [] + + if LoggingContext.current_context() == LoggingContext.sentinel: + logger.warning("Starting db txn '%s' from sentinel context", desc) + + try: + result = yield self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs + ) + + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + return result + + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + parent_context = LoggingContext.current_context() + if parent_context == LoggingContext.sentinel: + logger.warning( + "Starting db connection from sentinel context: metrics will be lost" + ) + parent_context = None + + start_time = monotonic_time() + + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection", parent_context) as context: + sched_duration_sec = monotonic_time() - start_time + sql_scheduling_timer.observe(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) + + if self.engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() + + return func(conn, *args, **kwargs) + + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) + + return result + + @staticmethod + def cursor_to_dict(cursor): + """Converts a SQL cursor into an list of dicts. + + Args: + cursor : The DBAPI cursor which has executed a query. + Returns: + A list of dicts where the key is the column header. + """ + col_headers = list(intern(str(column[0])) for column in cursor.description) + results = list(dict(zip(col_headers, row)) for row in cursor) + return results + + def execute(self, desc, decoder, query, *args): + """Runs a single query for a result set. + + Args: + decoder - The function which can resolve the cursor results to + something meaningful. + query - The query string to execute + *args - Query args. + Returns: + The result of decoder(results) + """ + + def interaction(txn): + txn.execute(query, args) + if decoder: + return decoder(txn) + else: + return txn.fetchall() + + return self.runInteraction(desc, interaction) + + # "Simple" SQL API methods that operate on a single table with no JOINs, + # no complex WHERE clauses, just a dict of values for columns. + + @defer.inlineCallbacks + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + """Executes an INSERT query on the named table. + + Args: + table : string giving the table name + values : dict of new column names and values for them + or_ignore : bool stating whether an exception should be raised + when a conflicting row already exists. If True, False will be + returned by the function instead + desc : string giving a description of the transaction + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True + """ + try: + yield self.runInteraction(desc, self.simple_insert_txn, table, values) + except self.engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise + return False + return True + + @staticmethod + def simple_insert_txn(txn, table, values): + keys, vals = zip(*values.items()) + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute(sql, vals) + + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + + @staticmethod + def simple_insert_many_txn(txn, table, values): + if not values: + return + + # This is a *slight* abomination to get a list of tuples of key names + # and a list of tuples of value names. + # + # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] + # + # The sort is to ensure that we don't rely on dictionary iteration + # order. + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) + + for k in keys: + if k != keys[0]: + raise RuntimeError("All items must have the same keys") + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.executemany(sql, vals) + + @defer.inlineCallbacks + def simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="simple_upsert", + lock=True, + ): + """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self.simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, + ) + return result + except self.engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warning( + "IntegrityError when upserting into %s; retrying: %s", table, e + ) + + def simple_upsert_txn( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + None or bool: Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + return self.simple_upsert_txn_native_upsert( + txn, table, keyvalues, values, insertion_values=insertion_values + ) + else: + return self.simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + bool: Return True if a new entry was created, False if an existing + one was updated. + """ + # We need to lock the table :(, unless we're *really* careful + if lock: + self.engine.lock_table(txn, table) + + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + + if not values: + # If `values` is empty, then all of the values we care about are in + # the unique key, so there is nothing to UPDATE. We can just do a + # SELECT instead to see if it exists. + sql = "SELECT 1 FROM %s WHERE %s" % ( + table, + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(keyvalues.values()) + txn.execute(sql, sqlargs) + if txn.fetchall(): + # We have an existing record. + return False + else: + # First try to update. + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(values.values()) + list(keyvalues.values()) + + txn.execute(sql, sqlargs) + if txn.rowcount > 0: + # successfully updated at least one row. + return False + + # We didn't find any existing rows, so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ) + txn.execute(sql, list(allvalues.values())) + # successfully inserted + return True + + def simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(insertion_values) + + if not values: + latter = "NOTHING" + else: + allvalues.update(values) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) + + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + latter, + ) + txn.execute(sql, list(allvalues.values())) + + def simple_upsert_many_txn( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + return self.simple_upsert_many_txn_native_upsert( + txn, table, key_names, key_values, value_names, value_values + ) + else: + return self.simple_upsert_many_txn_emulated( + txn, table, key_names, key_values, value_names, value_values + ) + + def simple_upsert_many_txn_emulated( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, but without native UPSERT support or batching. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + + for keyv, valv in zip(key_values, value_values): + _keys = {x: y for x, y in zip(key_names, keyv)} + _vals = {x: y for x, y in zip(value_names, valv)} + + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) + + def simple_upsert_many_txn_native_upsert( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, using batching where possible. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + allnames = [] + allnames.extend(key_names) + allnames.extend(value_names) + + if not value_names: + # No value columns, therefore make a blank list so that the + # following zip() works correctly. + latter = "NOTHING" + value_values = [() for x in range(len(key_values))] + else: + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names + ) + + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + args = [] + + for x, y in zip(key_values, value_values): + args.append(tuple(x) + tuple(y)) + + return txn.execute_batch(sql, args) + + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning multiple columns from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcols : list of strings giving the names of the columns to return + + allow_none : If true, return None instead of failing if the SELECT + statement returns no rows + """ + return self.runInteraction( + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + ) + + def simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="simple_select_one_onecol", + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcol : string giving the name of the column to return + """ + return self.runInteraction( + desc, + self.simple_select_one_onecol_txn, + table, + keyvalues, + retcol, + allow_none=allow_none, + ) + + @classmethod + def simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): + ret = cls.simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, retcol=retcol + ) + + if ret: + return ret[0] + else: + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + @staticmethod + def simple_select_onecol_txn(txn, table, keyvalues, retcol): + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} + + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + txn.execute(sql, list(keyvalues.values())) + else: + txn.execute(sql) + + return [r[0] for r in txn] + + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" + ): + """Executes a SELECT query on the named table, which returns a list + comprising of the values of the named column from the selected rows. + + Args: + table (str): table name + keyvalues (dict|None): column names and values to select the rows with + retcol (str): column whos value we wish to retrieve. + + Returns: + Deferred: Results in a list + """ + return self.runInteraction( + desc, self.simple_select_onecol_txn, table, keyvalues, retcol + ) + + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, self.simple_select_list_txn, table, keyvalues, retcols + ) + + @classmethod + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + txn.execute(sql, list(keyvalues.values())) + else: + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + txn.execute(sql) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="simple_select_many_batch", + batch_size=100, + ): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + results = [] + + if not iterable: + return results + + # iterables can not be sliced, so convert it to a list first + it_list = list(iterable) + + chunks = [ + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) + ] + for chunk in chunks: + rows = yield self.runInteraction( + desc, + self.simple_select_many_txn, + table, + column, + chunk, + keyvalues, + retcols, + ) + + results.extend(rows) + + return results + + @classmethod + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + if not iterable: + return [] + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) + + txn.execute(sql, values) + return cls.cursor_to_dict(txn) + + def simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, self.simple_update_txn, table, keyvalues, updatevalues + ) + + @staticmethod + def simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) + + return txn.rowcount + + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" + ): + """Executes an UPDATE query on the named table, setting new values for + columns in a row matching the key values. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + updatevalues : dict giving column names and values to update + retcols : optional list of column names to return + + If present, retcols gives a list of column names on which to perform + a SELECT statement *before* performing the UPDATE statement. The values + of these will be returned in a dict. + + These are performed within the same transaction, allowing an atomic + get-and-set. This can be used to implement compare-and-set by putting + the update column in the 'keyvalues' dict as well. + """ + return self.runInteraction( + desc, self.simple_update_one_txn, table, keyvalues, updatevalues + ) + + @classmethod + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) + + if rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + @staticmethod + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + select_sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(select_sql, list(keyvalues.values())) + row = txn.fetchone() + + if not row: + if allow_none: + return None + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + return dict(zip(retcols, row)) + + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + + @staticmethod + def simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + if txn.rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + + @staticmethod + def simple_delete_txn(txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount + + def simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted + """ + if not iterable: + return 0 + + sql = "DELETE FROM %s" % table + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + txn.execute(sql, values) + + return txn.rowcount + + def get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - %(limit)s" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + "limit": limit, + } + + sql = self.engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + + cache = {row[0]: int(row[1]) for row in txn} + + txn.close() + + if cache: + min_val = min(itervalues(cache)) + else: + min_val = max_value + + return cache, min_val + + def simple_select_list_paginate( + self, + table, + orderby, + start, + limit, + retcols, + filters=None, + keyvalues=None, + order_direction="ASC", + desc="simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self.simple_select_list_paginate_txn, + table, + orderby, + start, + limit, + retcols, + filters=filters, + keyvalues=keyvalues, + order_direction=order_direction, + ) + + @classmethod + def simple_select_list_paginate_txn( + cls, + txn, + table, + orderby, + start, + limit, + retcols, + filters=None, + keyvalues=None, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to + select attributes with exact matches. All constraints are joined together + using 'AND'. + + Args: + txn : Transaction object + table (str): the table name + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + + where_clause = "WHERE " if filters or keyvalues else "" + arg_list = [] + if filters: + where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) + arg_list += list(filters.values()) + where_clause += " AND " if filters and keyvalues else "" + if keyvalues: + where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues) + arg_list += list(keyvalues.values()) + + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, + ) + txn.execute(sql, arg_list + [limit, start]) + + return cls.cursor_to_dict(txn) + + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, self.simple_search_list_txn, table, term, col, retcols + ) + + @classmethod + def simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, Iterable]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if database_engine.supports_using_any_list: + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c76..000000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 3286804322e1..7b1845546969 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from functools import wraps @@ -64,12 +65,22 @@ def measure_func(name=None): def wrapper(func): block_name = func.__name__ if name is None else name - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r + + else: + + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = yield func(self, *args, **kwargs) + return r return measured_func @@ -80,72 +91,48 @@ class Measure(object): __slots__ = [ "clock", "name", - "start_context", + "_logging_context", "start", - "created_context", - "start_usage", ] def __init__(self, clock, name): self.clock = clock self.name = name - self.start_context = None + self._logging_context = None self.start = None - self.created_context = False def __enter__(self): - self.start = self.clock.time() - self.start_context = LoggingContext.current_context() - if not self.start_context: - self.start_context = LoggingContext("Measure") - self.start_context.__enter__() - self.created_context = True - - self.start_usage = self.start_context.get_resource_usage() + if self._logging_context: + raise RuntimeError("Measure() objects cannot be re-used") + self.start = self.clock.time() + parent_context = LoggingContext.current_context() + self._logging_context = LoggingContext( + "Measure[%s]" % (self.name,), parent_context + ) + self._logging_context.__enter__() in_flight.register((self.name,), self._update_in_flight) def __exit__(self, exc_type, exc_val, exc_tb): - if isinstance(exc_type, Exception) or not self.start_context: - return - - in_flight.unregister((self.name,), self._update_in_flight) + if not self._logging_context: + raise RuntimeError("Measure() block exited without being entered") duration = self.clock.time() - self.start + usage = self._logging_context.get_resource_usage() - block_counter.labels(self.name).inc() - block_timer.labels(self.name).inc(duration) - - context = LoggingContext.current_context() - - if context != self.start_context: - logger.warning( - "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, - context, - self.name, - ) - return - - if not context: - logger.warning("Expected context. (%r)", self.name) - return + in_flight.unregister((self.name,), self._update_in_flight) + self._logging_context.__exit__(exc_type, exc_val, exc_tb) - current = context.get_resource_usage() - usage = current - self.start_usage try: + block_counter.labels(self.name).inc() + block_timer.labels(self.name).inc(duration) block_ru_utime.labels(self.name).inc(usage.ru_utime) block_ru_stime.labels(self.name).inc(usage.ru_stime) block_db_txn_count.labels(self.name).inc(usage.db_txn_count) block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warning( - "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current - ) - - if self.created_context: - self.start_context.__exit__(exc_type, exc_val, exc_tb) + logger.warning("Failed to save metrics! Usage: %s", usage) def _update_in_flight(self, metrics): """Gets called when processing in flight metrics diff --git a/synmark/__init__.py b/synmark/__init__.py new file mode 100644 index 000000000000..afe4fad8cb4e --- /dev/null +++ b/synmark/__init__.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from twisted.internet import epollreactor +from twisted.internet.main import installReactor + +from synapse.config.homeserver import HomeServerConfig +from synapse.util import Clock + +from tests.utils import default_config, setup_test_homeserver + + +async def make_homeserver(reactor, config=None): + """ + Make a Homeserver suitable for running benchmarks against. + + Args: + reactor: A Twisted reactor to run under. + config: A HomeServerConfig to use, or None. + """ + cleanup_tasks = [] + clock = Clock(reactor) + + if not config: + config = default_config("test") + + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config, "", "") + + hs = await setup_test_homeserver( + cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock + ) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor.db.updates, "do_next_background_update"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + + def cleanup(): + for i in cleanup_tasks: + i() + + return hs, clock.sleep, cleanup + + +def make_reactor(): + """ + Instantiate and install a Twisted reactor suitable for testing (i.e. not the + default global one). + """ + reactor = epollreactor.EPollReactor() + + if "twisted.internet.reactor" in sys.modules: + del sys.modules["twisted.internet.reactor"] + installReactor(reactor) + + return reactor diff --git a/synmark/__main__.py b/synmark/__main__.py new file mode 100644 index 000000000000..ac59befbd497 --- /dev/null +++ b/synmark/__main__.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from contextlib import redirect_stderr +from io import StringIO + +import pyperf +from synmark import make_reactor +from synmark.suites import SUITES + +from twisted.internet.defer import ensureDeferred +from twisted.logger import globalLogBeginner, textFileLogObserver +from twisted.python.failure import Failure + +from tests.utils import setupdb + + +def make_test(main): + """ + Take a benchmark function and wrap it in a reactor start and stop. + """ + + def _main(loops): + + reactor = make_reactor() + + file_out = StringIO() + with redirect_stderr(file_out): + + d = ensureDeferred(main(reactor, loops)) + + def on_done(_): + if isinstance(_, Failure): + _.printTraceback() + print(file_out.getvalue()) + reactor.stop() + return _ + + d.addBoth(on_done) + reactor.run() + + return d.result + + return _main + + +if __name__ == "__main__": + + def add_cmdline_args(cmd, args): + if args.log: + cmd.extend(["--log"]) + + runner = pyperf.Runner( + processes=3, min_time=2, show_name=True, add_cmdline_args=add_cmdline_args + ) + runner.argparser.add_argument("--log", action="store_true") + runner.parse_args() + + orig_loops = runner.args.loops + runner.args.inherit_environ = ["SYNAPSE_POSTGRES"] + + if runner.args.worker: + if runner.args.log: + globalLogBeginner.beginLoggingTo( + [textFileLogObserver(sys.__stdout__)], redirectStandardIO=False + ) + setupdb() + + for suite, loops in SUITES: + if loops: + runner.args.loops = loops + else: + runner.args.loops = orig_loops + loops = "auto" + runner.bench_time_func( + suite.__name__ + "_" + str(loops), make_test(suite.main), + ) diff --git a/synmark/suites/__init__.py b/synmark/suites/__init__.py new file mode 100644 index 000000000000..cfa3b0ba3823 --- /dev/null +++ b/synmark/suites/__init__.py @@ -0,0 +1,3 @@ +from . import logging + +SUITES = [(logging, 1000), (logging, 10000), (logging, None)] diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py new file mode 100644 index 000000000000..d8e4c7d58f26 --- /dev/null +++ b/synmark/suites/logging.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from io import StringIO + +from mock import Mock + +from pyperf import perf_counter +from synmark import make_homeserver + +from twisted.internet.defer import Deferred +from twisted.internet.protocol import ServerFactory +from twisted.logger import LogBeginner, Logger, LogPublisher +from twisted.protocols.basic import LineOnlyReceiver + +from synapse.logging._structured import setup_structured_logging + + +class LineCounter(LineOnlyReceiver): + + delimiter = b"\n" + + def __init__(self, *args, **kwargs): + self.count = 0 + super().__init__(*args, **kwargs) + + def lineReceived(self, line): + self.count += 1 + + if self.count >= self.factory.wait_for and self.factory.on_done: + on_done = self.factory.on_done + self.factory.on_done = None + on_done.callback(True) + + +async def main(reactor, loops): + """ + Benchmark how long it takes to send `loops` messages. + """ + servers = [] + + def protocol(): + p = LineCounter() + servers.append(p) + return p + + logger_factory = ServerFactory.forProtocol(protocol) + logger_factory.wait_for = loops + logger_factory.on_done = Deferred() + port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1") + + hs, wait, cleanup = await make_homeserver(reactor) + + errors = StringIO() + publisher = LogPublisher() + mock_sys = Mock() + beginner = LogBeginner( + publisher, errors, mock_sys, warnings, initialBufferSize=loops + ) + + log_config = { + "loggers": {"synapse": {"level": "DEBUG"}}, + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": port.getHost().port, + "maximum_buffer": 100, + } + }, + } + + logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher) + logging_system = setup_structured_logging( + hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False + ) + + # Wait for it to connect... + await logging_system._observers[0]._service.whenConnected() + + start = perf_counter() + + # Send a bunch of useful messages + for i in range(0, loops): + logger.info("test message %s" % (i,)) + + if ( + len(logging_system._observers[0]._buffer) + == logging_system._observers[0].maximum_buffer + ): + while ( + len(logging_system._observers[0]._buffer) + > logging_system._observers[0].maximum_buffer / 2 + ): + await wait(0.01) + + await logger_factory.on_done + + end = perf_counter() - start + + logging_system.stop() + port.stopListening() + cleanup() + + return end diff --git a/sytest-blacklist b/sytest-blacklist index 411cce0692f2..79b2d4402aaa 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -33,3 +33,6 @@ New federated private chats get full presence information (SYN-115) # Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing # this requirement from the spec Inbound federation of state requires event_id as a mandatory paramater + +# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands +Can upload self-signing keys diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2dc50522496e..63d8633582ab 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py new file mode 100644 index 000000000000..27d83bb7d9b3 --- /dev/null +++ b/tests/federation/transport/test_server.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.unittest import override_config + + +class RoomDirectoryFederationTests(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + @override_config({"allow_public_rooms_over_federation": False}) + def test_blocked_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(403, channel.code) + + @override_config({"allow_public_rooms_over_federation": True}) + def test_open_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 854eb6c02450..fdfa2cbbc444 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,6 +183,10 @@ def test_replace_master_key(self): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + test_replace_master_key.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) + @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" @@ -503,3 +507,7 @@ def test_upload_signatures(self): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) + + test_upload_signatures.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e0075ccd328b..d9d312f0fb5e 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -42,16 +42,16 @@ def _add_background_updates(self): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ def _add_background_updates(self): ) def get_all_room_state(self): - return self.store._simple_select_list( + return self.store.db.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ def _get_current_stats(self, stats_type, stat_id): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store._simple_select_one( + self.store.db.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -108,8 +108,12 @@ def _perform_background_initial_update(self): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_initial_room(self): """ @@ -141,8 +145,12 @@ def test_initial_room(self): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r = self.get_success(self.get_all_room_state()) @@ -178,9 +186,9 @@ def test_initial_earliest_token(self): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_update_one( + self.store.db.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,14 +196,18 @@ def test_initial_earliest_token(self): ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now, before the table is actually ingested, add some more events. self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) @@ -205,13 +217,13 @@ def test_initial_earliest_token(self): # Now do the initial ingestion. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -221,9 +233,13 @@ def test_initial_earliest_token(self): ) ) - self.store._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + self.store.db.updates._all_done = False + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) self.reactor.advance(86401) @@ -653,15 +669,15 @@ def test_incomplete_stats(self): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -673,9 +689,9 @@ def test_incomplete_stats(self): # now do the background updates - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -685,7 +701,7 @@ def test_incomplete_stats(self): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -695,7 +711,7 @@ def test_incomplete_stats(self): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -705,8 +721,12 @@ def test_incomplete_stats(self): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 31f54bbd7daf..758ee071a5cc 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,54 +12,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION -from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.handlers.sync import SyncConfig from synapse.types import UserID import tests.unittest import tests.utils -from tests.utils import setup_test_homeserver -class SyncTestCase(tests.unittest.TestCase): +class SyncTestCase(tests.unittest.HomeserverTestCase): """ Tests Sync Handler. """ - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.sync_handler = SyncHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.hs = hs + self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:server" user_id2 = "@user2:server" sync_config = self._generate_sync_config(user_id1) + self.reactor.advance(100) # So we get not 0 time self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 # Check that the happy case does not throw errors - yield self.store.upsert_monthly_active_user(user_id1) - yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works self.hs.config.hs_disabled = True - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.hs.config.hs_disabled = False sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f6d8660285a3..92b87260932e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -163,7 +163,9 @@ def test_started_typing_local(self): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -227,7 +229,9 @@ def test_started_typing_remote_recv(self): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -279,7 +283,9 @@ def test_stopped_typing(self): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -300,7 +306,9 @@ def test_typing_timeout(self): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -317,7 +325,9 @@ def test_typing_timeout(self): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -335,7 +345,9 @@ def test_typing_timeout(self): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c5e91a8c41b8..26071059d246 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -158,7 +158,7 @@ def _compress_shared(self, shared): def get_users_in_public_rooms(self): r = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +169,7 @@ def get_users_in_public_rooms(self): def get_users_who_share_private_rooms(self): return self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -181,10 +181,10 @@ def _add_background_updates(self): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +193,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +203,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +213,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -255,8 +255,12 @@ def test_initial(self): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() @@ -290,8 +294,12 @@ def test_initial_share_all_users(self): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index e7472e3a9310..3dae83c543eb 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.storage.database import Database from tests import unittest from tests.server import FakeTransport @@ -42,7 +43,9 @@ def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() - self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) + self.slaved_store = self.STORE_TYPE( + Database(hs), self.hs.get_db_conn(), self.hs + ) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9575058252fb..0ed259438134 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -632,7 +632,7 @@ def test_purge_room(self): "state_groups_state", ): count = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 000000000000..5e9c07ebf3ff --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 66c2b687079c..0fdff79aa79a 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -15,6 +15,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -36,6 +38,7 @@ def make_homeserver(self, reactor, clock): ) hs.presence_handler = Mock() + hs.presence_handler.set_state.return_value = defer.succeed(None) return hs diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 140d8b377278..12c5e95cb5b6 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,6 +52,14 @@ def setUp(self): ] ) + self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( + Mock() + ) + hs = yield setup_test_homeserver( self.addCleanup, "test", @@ -63,7 +71,7 @@ def setUp(self): ) def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) + return defer.succeed(synapse.types.create_requester(myid)) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index eda2fabc7195..1ca7fa742fe9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -813,105 +815,6 @@ def test_stream_token_is_accepted_for_fwd_pagianation(self): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_filter_labels(self): - """Test that we can filter by a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] - ) - - def test_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - - events = self._test_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_filter_labels(self, message_filter): - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&filter=%s" - % (self.room_id, token, message_filter), - ) - self.render(request) - - return channel.json_body["chunk"] - def test_room_messages_purge(self): store = self.hs.get_datastore() pagination_handler = self.hs.get_pagination_handler() @@ -1320,3 +1223,377 @@ def _check_for_reason(self, reason): event_content = channel.json_body self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_context_filter_labels(self): + """Test that we can filter by a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = channel.json_body["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 2, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 4, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 1, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + Returns: + The ID of the event to use if we're testing filtering on /context. + """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 30fb77bac8dc..4bc3aaf02d2e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -109,7 +109,9 @@ def test_set_typing(self): self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8ea0cb05ea52..e7417b3d14d3 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 3283c0e47b3e..661c1f88b9bb 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 9b81b536f546..d491ea29242b 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -323,7 +323,7 @@ def prepare(self, reactor, clock, hs): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ def prepare(self, reactor, clock, hs): ) ) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ def test_upsert_many(self): value_values = [["hello"], ["there"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ def test_upsert_many(self): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ def test_upsert_many(self): value_values = [["bleb"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ def test_upsert_many(self): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index dfeea24599ab..2e521e9ab7b2 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,6 +28,7 @@ ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database from tests import unittest from tests.utils import setup_test_homeserver @@ -54,7 +55,8 @@ def setUp(self): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -123,7 +125,8 @@ def setUp(self): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = TestTransactionStore(database, hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -382,8 +385,8 @@ def test_get_appservices_by_state_multiple(self): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): @@ -416,7 +419,7 @@ def test_unique_works(self): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -432,7 +435,7 @@ def test_duplicate_ids(self): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -453,7 +456,7 @@ def test_duplicate_as_tokens(self): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fabe3fbc009..aec76f4ab183 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -15,7 +15,7 @@ def setUp(self): self.update_handler = Mock() - yield self.store.register_background_update_handler( + yield self.store.db.updates.register_background_update_handler( "test_update", self.update_handler ) @@ -23,7 +23,7 @@ def setUp(self): # (perhaps we should run them as part of the test HS setup, since we # run all of the other schema setup stuff there?) while True: - res = yield self.store.do_next_background_update(1000) + res = yield self.store.db.updates.do_next_background_update(1000) if res is None: break @@ -37,9 +37,9 @@ def test_do_background_update(self): def update(progress, count): self.clock.advance_time_msec(count * duration_ms) progress = {"my_key": progress["my_key"] + 1} - yield self.store.runInteraction( + yield self.store.db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.store.db.updates._background_update_progress_txn, "test_update", progress, ) @@ -47,29 +47,37 @@ def update(progress, count): self.update_handler.side_effect = update - yield self.store.start_background_update("test_update", {"my_key": 1}) + yield self.store.db.updates.start_background_update( + "test_update", {"my_key": 1} + ) self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") + yield self.store.db.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index c778de1f0c05..537cfe9f6494 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -59,13 +60,13 @@ def runWithConnection(func, *args, **kwargs): "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) - self.datastore = SQLBaseStore(None, hs) + self.datastore = SQLBaseStore(Database(hs), None, hs) @defer.inlineCallbacks def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +78,7 @@ def test_insert_1col(self): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +93,7 @@ def test_select_one_1col(self): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.db.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +107,7 @@ def test_select_one_3col(self): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +123,7 @@ def test_select_one_missing(self): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +138,7 @@ def test_select_list(self): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.db.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +151,7 @@ def test_select_list(self): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +166,7 @@ def test_update_one_1col(self): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +181,7 @@ def test_update_one_4cols(self): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.db.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 69dcaa63d58d..029ac264542d 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -46,7 +46,9 @@ def run_background_update(self): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, @@ -62,14 +64,20 @@ def run_delta_file(txn): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + self.store.db.runInteraction( + "test_delete_forward_extremities", run_delta_file + ) ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index afac5dec7f4e..bf674dd18458 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -37,9 +37,13 @@ def test_insert_new_client_ip(self): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,14 +51,14 @@ def test_insert_new_client_ip(self): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -81,7 +85,7 @@ def test_insert_new_client_ip_none_device_id(self): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -112,7 +116,7 @@ def test_insert_new_client_ip_none_device_id(self): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -202,25 +206,31 @@ def test_updating_monthly_active_user_when_space(self): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) - # Force persisting to disk self.reactor.advance(200) # But clear the associated entry in devices table self.get_success( - self.store._simple_update( + self.store.db.simple_update( table="devices", - keyvalues={"user_id": user_id, "device_id": "device_id"}, + keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, desc="test_devices_last_seen_bg_update", ) @@ -228,14 +238,14 @@ def test_devices_last_seen_bg_update(self): # We should now get nulls when querying result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, @@ -245,7 +255,7 @@ def test_devices_last_seen_bg_update(self): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -256,22 +266,26 @@ def test_devices_last_seen_bg_update(self): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # We should now get the correct result again result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, @@ -281,14 +295,21 @@ def test_devices_last_seen_bg_update(self): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -297,7 +318,7 @@ def test_old_user_ips_pruned(self): # We should see that in the DB result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -312,7 +333,7 @@ def test_old_user_ips_pruned(self): "access_token": "access_token", "ip": "ip", "user_agent": "user_agent", - "device_id": "device_id", + "device_id": device_id, "last_seen": 0, } ], @@ -323,7 +344,7 @@ def test_old_user_ips_pruned(self): # We should get no results. result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -335,14 +356,14 @@ def test_old_user_ips_pruned(self): # But we should still get the correct values for the device result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2fe50377f843..eadfb90a227d 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -61,7 +61,7 @@ def insert_event(txn, i): ) for i in range(0, 11): - yield self.store.runInteraction("insert", insert_event, i) + yield self.store.db.runInteraction("insert", insert_event, i) # this should get the last five and five others r = yield self.store.get_prev_events_for_room(room_id) @@ -93,9 +93,9 @@ def insert_event(txn, i, room_id): ) for i in range(0, 20): - yield self.store.runInteraction("insert", insert_event, i, room1) - yield self.store.runInteraction("insert", insert_event, i, room2) - yield self.store.runInteraction("insert", insert_event, i, room3) + yield self.store.db.runInteraction("insert", insert_event, i, room1) + yield self.store.db.runInteraction("insert", insert_event, i, room2) + yield self.store.db.runInteraction("insert", insert_event, i, room3) # Test simple case r = yield self.store.get_rooms_with_many_extremities(5, 5, []) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b114c6fb1d4e..d4bcf1821e56 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -55,7 +55,7 @@ def test_count_aggregation(self): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.runInteraction( + counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -74,7 +74,7 @@ def _inject_actions(stream, action): yield self.store.add_push_actions_to_staging( event.event_id, {user_id: action} ) - yield self.store.runInteraction( + yield self.store.db.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -82,12 +82,12 @@ def _inject_actions(stream, action): ) def _rotate(stream): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -116,7 +116,7 @@ def _mark_read(stream, depth): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.db.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ def _mark_read(stream, depth): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.db.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 90a63dc47784..3c78faab4528 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -65,7 +65,7 @@ def test_initialise_reserved_users(self): self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() @@ -183,7 +183,7 @@ def test_reap_monthly_active_users_reserved_users(self): ) self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) count = self.store.get_monthly_active_count() @@ -244,7 +244,7 @@ def test_get_reserved_real_user_account(self): {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 24c7fe16c329..9b6f7211aefe 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4561c3e383dc..dc4517335520 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -338,7 +338,7 @@ def test_redact_censor(self): ) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -356,7 +356,7 @@ def test_redact_censor(self): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 105a0c2b0201..7840f63fe381 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -122,8 +122,12 @@ def prepare(self, reactor, clock, homeserver): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -132,7 +136,7 @@ def test_can_rerun_update(self): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -143,8 +147,12 @@ def test_can_rerun_update(self): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7eea57c0e23d..6a545d2eb028 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 7d82b584664d..ad165d729591 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -33,6 +33,8 @@ def setUp(self): self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] + self.store = self.homeserver.get_datastore() + # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( @@ -77,10 +79,7 @@ def setUp(self): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0], "$join:test.serv", ) @@ -100,10 +99,7 @@ def post_json(destination, path, data, headers=None, timeout=0): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0] # Now lie about an event @@ -141,7 +137,5 @@ def post_json(destination, path, data, headers=None, timeout=0): ) # Make sure the invalid event isn't there - extrem = maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id - ) + extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") diff --git a/tests/unittest.py b/tests/unittest.py index 31997a0f31c8..b30b7d171819 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -18,6 +18,7 @@ import gc import hashlib import hmac +import inspect import logging import time @@ -25,7 +26,7 @@ from canonicaljson import json -from twisted.internet.defer import Deferred, succeed +from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest @@ -401,10 +402,12 @@ def setup_test_homeserver(self, *args, **kwargs): hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() - # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not self.get_success(stor.has_completed_background_updates()): - self.get_success(stor.do_next_background_update(1)) + # Run the database background updates, when running against "master". + if hs.__class__.__name__ == "TestHomeServer": + while not self.get_success( + stor.db.updates.has_completed_background_updates() + ): + self.get_success(stor.db.updates.do_next_background_update(1)) return hs @@ -415,6 +418,8 @@ def pump(self, by=0.0): self.reactor.pump([by] * 100) def get_success(self, d, by=0.0): + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump(by=by) @@ -424,6 +429,8 @@ def get_failure(self, d, exc): """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump() @@ -544,7 +551,7 @@ def add_extremity(self, room_id, event_id): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore()._simple_insert( + self.hs.get_datastore().db.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8b8455c8b7b0..281b32c4b893 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -179,6 +179,30 @@ def test_nested_logging_context(self): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.request, "foo-bar") + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_await(self): + # an async function which retuns an incomplete coroutine, but doesn't + # follow the synapse rules. + + async def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + await d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py deleted file mode 100644 index 1a44f7242551..000000000000 --- a/tests/util/test_snapshot_cache.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet.defer import Deferred - -from synapse.util.caches.snapshot_cache import SnapshotCache - -from .. import unittest - - -class SnapshotCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = SnapshotCache() - self.cache.DURATION_MS = 1 - - def test_get_set(self): - # Check that getting a missing key returns None - self.assertEquals(self.cache.get(0, "key"), None) - - # Check that setting a key with a deferred returns - # a deferred that resolves when the initial deferred does - d = Deferred() - set_result = self.cache.set(0, "key", d) - self.assertIsNotNone(set_result) - self.assertFalse(set_result.called) - - # Check that getting the key before the deferred has resolved - # returns a deferred that resolves when the initial deferred does. - get_result_at_10 = self.cache.get(10, "key") - self.assertIsNotNone(get_result_at_10) - self.assertFalse(get_result_at_10.called) - - # Check that the returned deferreds resolve when the initial deferred - # does. - d.callback("v") - self.assertTrue(set_result.called) - self.assertTrue(get_result_at_10.called) - - # Check that getting the key after the deferred has resolved - # before the cache expires returns a resolved deferred. - get_result_at_11 = self.cache.get(11, "key") - self.assertIsNotNone(get_result_at_11) - if isinstance(get_result_at_11, Deferred): - # The cache may return the actual result rather than a deferred - self.assertTrue(get_result_at_11.called) - - # Check that getting the key after the deferred has resolved - # after the cache expires returns None - get_result_at_12 = self.cache.get(12, "key") - self.assertIsNone(get_result_at_12) diff --git a/tests/utils.py b/tests/utils.py index de2ac1ed339a..c57da59191cc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -461,7 +461,9 @@ def trigger( try: args = [urlparse.unquote(u) for u in matcher.groups()] - (code, response) = yield func(mock_request, *args) + (code, response) = yield defer.ensureDeferred( + func(mock_request, *args) + ) return code, response except CodeMessageException as e: return (e.code, cs_error(e.msg, code=e.errcode)) diff --git a/tox.ini b/tox.ini index 62b350ea6a94..903a245fb027 100644 --- a/tox.ini +++ b/tox.ini @@ -102,6 +102,15 @@ commands = {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} +[testenv:benchmark] +deps = + {[base]deps} + pyperf +setenv = + SYNAPSE_POSTGRES = 1 +commands = + python -m synmark {posargs:} + [testenv:packaging] skip_install=True deps =