Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Tidy up and type-hint the database engine modules #12734

Merged
merged 10 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions synapse/storage/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping

from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine


def create_engine(database_config) -> BaseDatabaseEngine:
def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
name = database_config["name"]

if name == "sqlite3":
import sqlite3

return Sqlite3Engine(sqlite3, database_config)
return Sqlite3Engine(database_config)

if name == "psycopg2":
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
import psycopg2

return PostgresEngine(psycopg2, database_config)
return PostgresEngine(database_config)

raise RuntimeError("Unsupported database engine '%s'" % (name,))

Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/engines/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import abc
from enum import IntEnum
from typing import Generic, Optional, TypeVar
from typing import Any, Generic, Mapping, Optional, TypeVar

from synapse.storage.types import Connection

Expand All @@ -32,7 +32,7 @@ class IncorrectDatabaseSetup(RuntimeError):


class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def __init__(self, module, database_config: dict):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poor commit hygiene.

def __init__(self, module, database_config: Mapping[str, Any]):
self.module = module

@property
Expand Down
37 changes: 23 additions & 14 deletions synapse/storage/engines/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Mapping, Optional
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast

from synapse.storage.engines._base import (
BaseDatabaseEngine,
Expand All @@ -22,30 +22,35 @@
)
from synapse.storage.types import Connection

if TYPE_CHECKING:
import psycopg2 # noqa: F401

logger = logging.getLogger(__name__)


class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
def __init__(self, database_config: Mapping[str, Any]):
import psycopg2.extensions
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dislike importing this here rather than at the top level. But this was the least-bad/easiest way I could see to not require existing installations to install psycopg2.


super().__init__(psycopg2, database_config)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# actually want to use bytes than wrap it in `bytearray`.
def _disable_bytes_adapter(_):
raise Exception("Passing bytes to DB is disabled.")

self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
self._version: Optional[int] = None # unknown as yet

self.isolation_level_map: Mapping[int, int] = {
IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE,
IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
self.config = database_config

Expand All @@ -65,7 +70,7 @@ def check_database(self, db_conn, allow_outdated_version: bool = False):
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
self._version = cast(int, db_conn.server_version)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If accepted, python/typeshed#7834 would make this cast redundant.

allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)

# Are we on a supported PostgreSQL version?
Expand Down Expand Up @@ -166,7 +171,9 @@ def supports_returning(self) -> bool:
return True

def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
import psycopg2.extensions

if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
# "40001" serialization_failure
# "40P01" deadlock_detected
Expand Down Expand Up @@ -198,7 +205,9 @@ def server_version(self):
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

def in_transaction(self, conn: Connection) -> bool:
return conn.status != self.module.extensions.STATUS_READY # type: ignore
import psycopg2.extensions

return conn.status != psycopg2.extensions.STATUS_READY # type: ignore

def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
return conn.set_session(autocommit=autocommit) # type: ignore
Expand Down
23 changes: 10 additions & 13 deletions synapse/storage/engines/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import sqlite3
import struct
import threading
import typing
from typing import Optional
from typing import Any, Mapping, Optional

from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Connection

if typing.TYPE_CHECKING:
import sqlite3 # noqa: F401


class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def __init__(self, database_module, database_config):
super().__init__(database_module, database_config)
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
def __init__(self, database_config: Mapping[str, Any]):
super().__init__(sqlite3, database_config)

database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (
Expand All @@ -37,7 +34,7 @@ def __init__(self, database_module, database_config):
if platform.python_implementation() == "PyPy":
# pypy's sqlite3 module doesn't handle bytearrays, convert them
# back to bytes.
database_module.register_adapter(bytearray, lambda array: bytes(array))
sqlite3.register_adapter(bytearray, lambda array: bytes(array))

# The current max state_group, or None if we haven't looked
# in the DB yet.
Expand All @@ -54,7 +51,7 @@ def can_native_upsert(self):
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
more work we haven't done yet to tell what was inserted vs updated.
"""
return self.module.sqlite_version_info >= (3, 24, 0)
return sqlite3.sqlite_version_info >= (3, 24, 0)

@property
def supports_using_any_list(self):
Expand All @@ -64,11 +61,11 @@ def supports_using_any_list(self):
@property
def supports_returning(self) -> bool:
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return self.module.sqlite_version_info >= (3, 35, 0)
return sqlite3.sqlite_version_info >= (3, 35, 0)

def check_database(self, db_conn, allow_outdated_version: bool = False):
if not allow_outdated_version:
version = self.module.sqlite_version_info
version = sqlite3.sqlite_version_info
# Synapse is untested against older SQLite versions, and we don't want
# to let users upgrade to a version of Synapse with broken support for their
# sqlite version, because it risks leaving them with a half-upgraded db.
Expand Down Expand Up @@ -113,7 +110,7 @@ def server_version(self):
Returns:
string
"""
return "%i.%i.%i" % self.module.sqlite_version_info
return "%i.%i.%i" % sqlite3.sqlite_version_info

def in_transaction(self, conn: Connection) -> bool:
return conn.in_transaction # type: ignore
Expand Down