Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Tidy up and type-hint the database engine modules #12734

Merged
merged 10 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12734.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Tidy up and type-hint the database engine modules.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True

[mypy-synapse.storage.engines.*]
disallow_untyped_defs = True

[mypy-synapse.storage.prepare_database]
disallow_untyped_defs = True

Expand Down
12 changes: 4 additions & 8 deletions synapse/storage/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping

from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine


def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
    """Instantiate the database engine selected by ``database_config["name"]``.

    Args:
        database_config: the ``database`` section of the homeserver config.
            Must contain a ``"name"`` key of either ``"sqlite3"`` or
            ``"psycopg2"``.

    Returns:
        A configured ``Sqlite3Engine`` or ``PostgresEngine``.

    Raises:
        RuntimeError: if the named engine is not supported.
    """
    name = database_config["name"]

    if name == "sqlite3":
        return Sqlite3Engine(database_config)

    if name == "psycopg2":
        # Note that psycopg2cffi-compat provides the psycopg2 module on pypy;
        # PostgresEngine imports it lazily.
        return PostgresEngine(database_config)

    raise RuntimeError("Unsupported database engine '%s'" % (name,))

Expand Down
26 changes: 16 additions & 10 deletions synapse/storage/engines/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.
import abc
from enum import IntEnum
from typing import Generic, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar

from synapse.storage.types import Connection
from synapse.storage.types import Connection, Cursor, DBAPI2Module

if TYPE_CHECKING:
from synapse.storage.database import LoggingDatabaseConnection


class IsolationLevel(IntEnum):
Expand All @@ -32,7 +35,7 @@ class IncorrectDatabaseSetup(RuntimeError):


class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def __init__(self, module, database_config: dict):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poor commit hygiene.

def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
self.module = module

@property
Expand Down Expand Up @@ -69,7 +72,7 @@ def check_database(
...

@abc.abstractmethod
def check_new_database(self, txn) -> None:
def check_new_database(self, txn: Cursor) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
Expand All @@ -79,8 +82,11 @@ def check_new_database(self, txn) -> None:
def convert_param_style(self, sql: str) -> str:
...

# This method would ideally take a plain ConnectionType, but it seems that
# the Sqlite engine expects to use LoggingDatabaseConnection.cursor
# instead of sqlite3.Connection.cursor: only the former takes a txn_name.
@abc.abstractmethod
def on_new_connection(self, db_conn: ConnectionType) -> None:
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
...

@abc.abstractmethod
Expand All @@ -92,7 +98,7 @@ def is_connection_closed(self, conn: ConnectionType) -> bool:
...

@abc.abstractmethod
def lock_table(self, txn, table: str) -> None:
def lock_table(self, txn: Cursor, table: str) -> None:
...

@property
Expand All @@ -102,12 +108,12 @@ def server_version(self) -> str:
...

@abc.abstractmethod
def in_transaction(self, conn: Connection) -> bool:
def in_transaction(self, conn: ConnectionType) -> bool:
"""Whether the connection is currently in a transaction."""
...

@abc.abstractmethod
def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> None:
"""Attempt to set the connections autocommit mode.

When True queries are run outside of transactions.
Expand All @@ -119,8 +125,8 @@ def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):

@abc.abstractmethod
def attempt_to_set_isolation_level(
self, conn: Connection, isolation_level: Optional[int]
):
self, conn: ConnectionType, isolation_level: Optional[int]
) -> None:
"""Attempt to set the connections isolation level.

Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
Expand Down
92 changes: 52 additions & 40 deletions synapse/storage/engines/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,59 +13,69 @@
# limitations under the License.

import logging
from typing import Mapping, Optional
from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast

from synapse.storage.engines._base import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
IsolationLevel,
)
from synapse.storage.types import Connection
from synapse.storage.types import Cursor

if TYPE_CHECKING:
import psycopg2 # noqa: F401

from synapse.storage.database import LoggingDatabaseConnection


logger = logging.getLogger(__name__)


class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
def __init__(self, database_config: Mapping[str, Any]):
import psycopg2.extensions
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dislike importing this here rather than at the top level. But this was the least-bad/easiest way I could see to not require existing installations to install psycopg2.


super().__init__(psycopg2, database_config)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# actually want to use bytes then wrap it in `bytearray`.
def _disable_bytes_adapter(_):
def _disable_bytes_adapter(_: bytes) -> NoReturn:
raise Exception("Passing bytes to DB is disabled.")

self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
self._version: Optional[int] = None # unknown as yet

self.isolation_level_map: Mapping[int, int] = {
IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE,
IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
self.config = database_config

@property
def single_threaded(self) -> bool:
return False

def get_db_locale(self, txn: Cursor) -> Tuple[str, str]:
    """Fetch the collation and ctype of the current database.

    Returns:
        A ``(datcollate, datctype)`` pair, as recorded in ``pg_database``.
    """
    txn.execute(
        "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
    )
    # fetchone() is typed as Optional, but this query always yields exactly
    # one row for the current database, so the cast is safe.
    collation, ctype = cast(Tuple[str, str], txn.fetchone())
    return collation, ctype

def check_database(self, db_conn, allow_outdated_version: bool = False):
def check_database(
self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False
) -> None:
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
self._version = cast(int, db_conn.server_version)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If accepted, python/typeshed#7834 would make this cast redundant.

allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)

# Are we on a supported PostgreSQL version?
Expand Down Expand Up @@ -108,7 +118,7 @@ def check_database(self, db_conn, allow_outdated_version: bool = False):
ctype,
)

def check_new_database(self, txn):
def check_new_database(self, txn: Cursor) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
Expand All @@ -129,10 +139,10 @@ def check_new_database(self, txn):
"See docs/postgres.md for more information." % ("\n".join(errors))
)

def convert_param_style(self, sql: str) -> str:
    """Rewrite Synapse's '?' placeholders into psycopg2's '%s' style."""
    return sql.replace("?", "%s")

def on_new_connection(self, db_conn):
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
db_conn.set_isolation_level(self.default_isolation_level)

# Set the bytea output to escape, vs the default of hex
Expand All @@ -149,14 +159,14 @@ def on_new_connection(self, db_conn):
db_conn.commit()

@property
def can_native_upsert(self):
def can_native_upsert(self) -> bool:
"""
Can we use native UPSERTs?
"""
return True

@property
def supports_using_any_list(self):
def supports_using_any_list(self) -> bool:
"""Do we support using `a = ANY(?)` and passing a list"""
return True

Expand All @@ -165,27 +175,25 @@ def supports_returning(self) -> bool:
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True

def is_deadlock(self, error: Exception) -> bool:
    """Does `error` indicate a serialization failure or deadlock?"""
    # Imported per-call so that this module stays importable even when
    # psycopg2 is not installed.
    import psycopg2.extensions

    if isinstance(error, psycopg2.DatabaseError):
        # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
        # "40001" serialization_failure
        # "40P01" deadlock_detected
        return error.pgcode in ["40001", "40P01"]
    return False

def is_connection_closed(self, conn: "psycopg2.connection") -> bool:
    # Normalise the driver's `closed` flag (which may be an int) to a bool.
    return bool(conn.closed)

def lock_table(self, txn: Cursor, table: str) -> None:
    # NB: `table` is interpolated directly into the SQL — callers must only
    # pass trusted, internally-generated table names, never user input.
    txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))

@property
def server_version(self):
"""Returns a string giving the server version. For example: '8.1.5'

Returns:
string
"""
def server_version(self) -> str:
"""Returns a string giving the server version. For example: '8.1.5'."""
# note that this is a bit of a hack because it relies on check_database
# having been called. Still, that should be a safe bet here.
numver = self._version
Expand All @@ -197,17 +205,21 @@ def server_version(self):
else:
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

def in_transaction(self, conn: "psycopg2.connection") -> bool:
    """Whether the connection is currently in a transaction."""
    # Imported per-call so that this module stays importable without
    # psycopg2; re-importing an already-loaded module is cheap
    # (measured at ~10 microseconds worst case in review).
    import psycopg2.extensions

    return conn.status != psycopg2.extensions.STATUS_READY
Comment on lines +208 to +211
Copy link
Contributor

@squahtx squahtx May 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we'll be running the import on every database transaction. I think it's fine?
Reimporting a previously imported module didn't seem to take too long in a very unscientific test using the repl:

>>> t0 = time.time(); import psycopg2.extensions; time.time() - t0
1.0013580322265625e-05 (~10 microseconds, worst number)
>>> t0 = time.time(); time.time() - t0
2.384185791015625e-06 (~2.5 microseconds, worst number)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd call it "ugly, but fine".

The obstacle blocking us from importing psycopg2 freely in this module is that we expect to be able to import PostgresEngine everywhere, even if psycopg2 is not available. For example:

if isinstance(database.engine, PostgresEngine):

if isinstance(self.database_engine, PostgresEngine):

if isinstance(self.database_engine, PostgresEngine):

I mulled over doing something like

  • rename postgres.py to _postgres.py
  • create a new file postgres.py with contents like this:
    try:
        from ._postgres import PostgresEngine
    except ImportError:
        class PostgresEngine: pass
    
    __all__ = ["PostgresEngine"] 

But I'm not sure how well Pycharm and mypy will handle that. Other ideas:

  • have DatabaseEngine have some kind of "engine_kind" attribute which points to an enum. Felt like duplicating information we already know, and just generally gave me a smelly impression.
  • make psycopg2 a mandatory requirement for Synapse, That'd be my preferred option, but I suspect it'd be controversial.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DMRobertson another idea would be to add more flags to database engine that talk about the features as opposed to the specific class. E.g. it looks like those need (respectively):

  • Sequence support
  • RECURSIVE support
  • DISTINCT ON support

E.g. expand the features listed near

@property
def can_native_upsert(self):
"""
Can we use native UPSERTs?
"""
return True
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list"""
return True
@property
def supports_returning(self) -> bool:
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea, though I think it'd be tricky to retrofit those in across the source tree without being familiar with the differences between the two databases.

(Obligatory grumble: I don't want us to be rewriting and maintaining our own version of SQLAlchemy core)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely fair, not trying to scope creep! 👍 I've thought about doing this in the past also!


def attempt_to_set_autocommit(
    self, conn: "psycopg2.connection", autocommit: bool
) -> None:
    """Attempt to set the connection's autocommit mode.

    When True, queries are run outside of transactions.
    """
    return conn.set_session(autocommit=autocommit)

def attempt_to_set_isolation_level(
    self, conn: "psycopg2.connection", isolation_level: Optional[int]
) -> None:
    """Attempt to set the connection's isolation level.

    Args:
        conn: the connection to configure.
        isolation_level: an ``IsolationLevel`` value to map onto psycopg2's
            constant, or None to restore the engine's default
            (REPEATABLE READ).
    """
    if isolation_level is None:
        isolation_level = self.default_isolation_level
    else:
        isolation_level = self.isolation_level_map[isolation_level]
    return conn.set_isolation_level(isolation_level)
Loading