diff --git a/.travis.yml b/.travis.yml index 6bbc7e7d..1506b4ec 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: - "3.7" env: - - TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database" + - TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db" services: - postgresql diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index ddeb08cb..7408d899 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -5,7 +5,8 @@ import aiomysql from sqlalchemy.dialects.mysql import pymysql -from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.engine.interfaces import Dialect, ExecutionContext +from sqlalchemy.engine.result import ResultMetaData, RowProxy from sqlalchemy.sql import ClauseElement from sqlalchemy.types import TypeEngine @@ -14,12 +15,10 @@ logger = logging.getLogger("databases") -_result_processors = {} # type: dict - class MySQLBackend(DatabaseBackend): - def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None: - self._database_url = DatabaseURL(database_url) + def __init__(self, database_url: DatabaseURL) -> None: + self._database_url = database_url self._dialect = pymysql.dialect(paramstyle="pyformat") self._pool = None @@ -45,28 +44,9 @@ def connection(self) -> "MySQLConnection": return MySQLConnection(self._pool, self._dialect) -class Record: - def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None: - self._row = row - self._result_columns = result_columns - self._dialect = dialect - self._column_map = { - column_name: (idx, datatype) - for idx, (column_name, _, _, datatype) in enumerate(self._result_columns) - } - - def __getitem__(self, key: str) -> typing.Any: - idx, datatype = self._column_map[key] - raw = self._row[idx] - try: - processor = _result_processors[datatype] - except KeyError: - processor = datatype.result_processor(self._dialect, None) - 
_result_processors[datatype] = processor - - if processor is not None: - return processor(raw) - return raw +class CompilationContext: + def __init__(self, context: ExecutionContext): + self.context = context class MySQLConnection(ConnectionBackend): @@ -84,25 +64,30 @@ async def release(self) -> None: await self._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.Any: + async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]: assert self._connection is not None, "Connection is not acquired" - query, args, result_columns = self._compile(query) + query, args, context = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query, args) rows = await cursor.fetchall() - return [Record(row, result_columns, self._dialect) for row in rows] + metadata = ResultMetaData(context, cursor.description) + return [ + RowProxy(metadata, row, metadata._processors, metadata._keymap) + for row in rows + ] finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Any: + async def fetch_one(self, query: ClauseElement) -> RowProxy: assert self._connection is not None, "Connection is not acquired" - query, args, result_columns = self._compile(query) + query, args, context = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query, args) row = await cursor.fetchone() - return Record(row, result_columns, self._dialect) + metadata = ResultMetaData(context, cursor.description) + return RowProxy(metadata, row, metadata._processors, metadata._keymap) finally: await cursor.close() @@ -110,7 +95,7 @@ async def execute(self, query: ClauseElement, values: dict = None) -> None: assert self._connection is not None, "Connection is not acquired" if values is not None: query = query.values(values) - query, args, result_columns = self._compile(query) + query, args, context = self._compile(query) cursor = await 
self._connection.cursor() try: await cursor.execute(query, args) @@ -123,7 +108,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None: try: for item in values: single_query = query.values(item) - single_query, args, result_columns = self._compile(single_query) + single_query, args, context = self._compile(single_query) await cursor.execute(single_query, args) finally: await cursor.close() @@ -132,26 +117,38 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query, args, result_columns = self._compile(query) + query, args, context = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query, args) + metadata = ResultMetaData(context, cursor.description) async for row in cursor: - yield Record(row, result_columns, self._dialect) + yield RowProxy(metadata, row, metadata._processors, metadata._keymap) finally: await cursor.close() def transaction(self) -> TransactionBackend: return MySQLTransaction(self) - def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: + def _compile( + self, query: ClauseElement + ) -> typing.Tuple[str, dict, CompilationContext]: compiled = query.compile(dialect=self._dialect) args = compiled.construct_params() - logger.debug(compiled.string, args) for key, val in args.items(): if key in compiled._bind_processors: args[key] = compiled._bind_processors[key](val) - return compiled.string, args, compiled._result_columns + + execution_context = self._dialect.execution_ctx_cls() + execution_context.dialect = self._dialect + execution_context.result_column_struct = ( + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + ) + + logger.debug(compiled.string, args) + return compiled.string, args, CompilationContext(execution_context) class MySQLTransaction(TransactionBackend): @@ -176,10 +173,7 @@ async def start(self, 
is_root: bool) -> None: async def commit(self) -> None: assert self._connection._connection is not None, "Connection is not acquired" - if self._is_root: # pragma: no cover - # In test cases the root transaction is never committed, - # since we *always* wrap the test case up in a transaction - # and rollback to a clean state at the end. + if self._is_root: await self._connection._connection.commit() else: cursor = await self._connection._connection.cursor() diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 8579a8dc..301bb96a 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -16,8 +16,8 @@ class PostgresBackend(DatabaseBackend): - def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None: - self._database_url = DatabaseURL(database_url) + def __init__(self, database_url: DatabaseURL) -> None: + self._database_url = database_url self._dialect = self._get_dialect() self._pool = None diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py new file mode 100644 index 00000000..bf24f226 --- /dev/null +++ b/databases/backends/sqlite.py @@ -0,0 +1,194 @@ +import logging +import typing +import uuid + +import aiosqlite +from sqlalchemy.dialects.sqlite import pysqlite +from sqlalchemy.engine.interfaces import Dialect, ExecutionContext +from sqlalchemy.engine.result import ResultMetaData, RowProxy +from sqlalchemy.sql import ClauseElement +from sqlalchemy.types import TypeEngine + +from databases.core import DatabaseURL +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend + +logger = logging.getLogger("databases") + + +class SQLiteBackend(DatabaseBackend): + def __init__(self, database_url: DatabaseURL) -> None: + self._database_url = database_url + self._dialect = pysqlite.dialect(paramstyle="qmark") + self._pool = SQLitePool(database_url) + + async def connect(self) -> None: + pass + # assert self._pool is None, "DatabaseBackend is already 
running" + # self._pool = await aiomysql.create_pool( + # host=self._database_url.hostname, + # port=self._database_url.port or 3306, + # user=self._database_url.username or getpass.getuser(), + # password=self._database_url.password, + # db=self._database_url.database, + # autocommit=True, + # ) + + async def disconnect(self) -> None: + pass + # assert self._pool is not None, "DatabaseBackend is not running" + # self._pool.close() + # await self._pool.wait_closed() + # self._pool = None + + def connection(self) -> "SQLiteConnection": + return SQLiteConnection(self._pool, self._dialect) + + +class SQLitePool: + def __init__(self, url: DatabaseURL) -> None: + self._url = url + + async def acquire(self) -> aiosqlite.Connection: + connection = aiosqlite.connect( + database=self._url.database, isolation_level=None + ) + await connection.__aenter__() + return connection + + async def release(self, connection: aiosqlite.Connection) -> None: + await connection.__aexit__(None, None, None) + + +class CompilationContext: + def __init__(self, context: ExecutionContext): + self.context = context + + +class SQLiteConnection(ConnectionBackend): + def __init__(self, pool: SQLitePool, dialect: Dialect): + self._pool = pool + self._dialect = dialect + self._connection = None + + async def acquire(self) -> None: + assert self._connection is None, "Connection is already acquired" + self._connection = await self._pool.acquire() + + async def release(self) -> None: + assert self._connection is not None, "Connection is not acquired" + await self._pool.release(self._connection) + self._connection = None + + async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + + async with self._connection.execute(query, args) as cursor: + rows = await cursor.fetchall() + metadata = ResultMetaData(context, cursor.description) + return [ + RowProxy(metadata, row, 
metadata._processors, metadata._keymap) + for row in rows + ] + + async def fetch_one(self, query: ClauseElement) -> RowProxy: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + + async with self._connection.execute(query, args) as cursor: + row = await cursor.fetchone() + metadata = ResultMetaData(context, cursor.description) + return RowProxy(metadata, row, metadata._processors, metadata._keymap) + + async def execute(self, query: ClauseElement, values: dict = None) -> None: + assert self._connection is not None, "Connection is not acquired" + if values is not None: + query = query.values(values) + query, args, context = self._compile(query) + cursor = await self._connection.execute(query, args) + await cursor.close() + + async def execute_many(self, query: ClauseElement, values: list) -> None: + assert self._connection is not None, "Connection is not acquired" + for value in values: + await self.execute(query, value) + + async def iterate( + self, query: ClauseElement + ) -> typing.AsyncGenerator[typing.Any, None]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + async with self._connection.execute(query, args) as cursor: + metadata = ResultMetaData(context, cursor.description) + async for row in cursor: + yield RowProxy(metadata, row, metadata._processors, metadata._keymap) + + def transaction(self) -> TransactionBackend: + return SQLiteTransaction(self) + + def _compile( + self, query: ClauseElement + ) -> typing.Tuple[str, list, CompilationContext]: + compiled = query.compile(dialect=self._dialect) + args = [] + for key, raw_val in compiled.construct_params().items(): + if key in compiled._bind_processors: + val = compiled._bind_processors[key](raw_val) + else: + val = raw_val + args.append(val) + + execution_context = self._dialect.execution_ctx_cls() + execution_context.dialect = 
self._dialect + execution_context.result_column_struct = ( + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + ) + + logger.debug(compiled.string, args) + return compiled.string, args, CompilationContext(execution_context) + + +class SQLiteTransaction(TransactionBackend): + def __init__(self, connection: SQLiteConnection): + self._connection = connection + self._is_root = False + self._savepoint_name = "" + + async def start(self, is_root: bool) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + self._is_root = is_root + if self._is_root: + cursor = await self._connection._connection.execute("BEGIN") + await cursor.close() + else: + id = str(uuid.uuid4()).replace("-", "_") + self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" + cursor = await self._connection._connection.execute( + f"SAVEPOINT {self._savepoint_name}" + ) + await cursor.close() + + async def commit(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + if self._is_root: + cursor = await self._connection._connection.execute("COMMIT") + await cursor.close() + else: + cursor = await self._connection._connection.execute( + f"RELEASE SAVEPOINT {self._savepoint_name}" + ) + await cursor.close() + + async def rollback(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + if self._is_root: + cursor = await self._connection._connection.execute("ROLLBACK") + await cursor.close() + else: + cursor = await self._connection._connection.execute( + f"ROLLBACK TO SAVEPOINT {self._savepoint_name}" + ) + await cursor.close() diff --git a/databases/core.py b/databases/core.py index 644a9619..2bf778aa 100644 --- a/databases/core.py +++ b/databases/core.py @@ -5,6 +5,7 @@ from types import TracebackType from urllib.parse import SplitResult, urlsplit +from sqlalchemy.engine import RowProxy from sqlalchemy.sql import ClauseElement from databases.importer 
import import_from_string @@ -20,6 +21,7 @@ class Database: SUPPORTED_BACKENDS = { "postgresql": "databases.backends.postgres:PostgresBackend", "mysql": "databases.backends.mysql:MySQLBackend", + "sqlite": "databases.backends.sqlite:SQLiteBackend", } def __init__( @@ -27,7 +29,6 @@ def __init__( ): self._url = DatabaseURL(url) self._force_rollback = force_rollback - self.is_connected = False backend_str = self.SUPPORTED_BACKENDS[self._url.dialect] @@ -44,6 +45,9 @@ def __init__( self._global_transaction = None # type: typing.Optional[Transaction] async def connect(self) -> None: + """ + Establish the connection pool. + """ assert not self.is_connected, "Already connected." await self._backend.connect() @@ -57,6 +61,9 @@ async def connect(self) -> None: await self._global_transaction.__aenter__() async def disconnect(self) -> None: + """ + Close all connections in the connection pool. + """ assert self.is_connected, "Already disconnected." if self._force_rollback: @@ -80,11 +87,11 @@ async def __aexit__( ) -> None: await self.disconnect() - async def fetch_all(self, query: ClauseElement) -> typing.Any: + async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]: async with self.connection() as connection: return await connection.fetch_all(query=query) - async def fetch_one(self, query: ClauseElement) -> typing.Any: + async def fetch_one(self, query: ClauseElement) -> RowProxy: async with self.connection() as connection: return await connection.fetch_one(query=query) @@ -98,7 +105,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[RowProxy, None]: async with self.connection() as connection: async for record in connection.iterate(query): yield record @@ -240,10 +247,7 @@ async def rollback(self) -> None: class DatabaseURL: def __init__(self, url: typing.Union[str, "DatabaseURL"]): - if isinstance(url, 
DatabaseURL): - self._url = str(url) - else: - self._url = url + self._url = str(url) @property def components(self) -> SplitResult: diff --git a/databases/interfaces.py b/databases/interfaces.py index d611ae36..5d4e4cce 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -1,5 +1,6 @@ import typing +from sqlalchemy.engine import RowProxy from sqlalchemy.sql import ClauseElement @@ -21,10 +22,10 @@ async def acquire(self) -> None: async def release(self) -> None: raise NotImplementedError() # pragma: no cover - async def fetch_all(self, query: ClauseElement) -> typing.Any: + async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]: raise NotImplementedError() # pragma: no cover - async def fetch_one(self, query: ClauseElement) -> typing.Any: + async def fetch_one(self, query: ClauseElement) -> RowProxy: raise NotImplementedError() # pragma: no cover async def execute(self, query: ClauseElement, values: dict = None) -> None: @@ -35,7 +36,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[RowProxy, None]: raise NotImplementedError() # pragma: no cover # mypy needs async iterators to contain a `yield` # https://github.com/python/mypy/issues/5385#issuecomment-407281656 diff --git a/requirements.txt b/requirements.txt index 46e0a7c1..83725adb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ -sqlalchemy +sqlalchemy==1.3.0b3 aiocontextvars;python_version<"3.7" # Async database drivers -asyncpg aiomysql +aiosqlite +asyncpg # Sync database drivers for standard tooling around setup/teardown/migrations. 
psycopg2-binary diff --git a/tests/test_databases.py b/tests/test_databases.py index a9175ff7..b1a974a1 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -376,6 +376,30 @@ async def test_connections_isolation(database_url): await database.execute(query) +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_commit_on_root_transaction(database_url): + """ + Because our tests are generally wrapped in rollback-isolation, they + don't have coverage for committing the root transaction. + + Deal with this here, and delete the records rather than rolling back. + """ + + async with Database(database_url) as database: + try: + async with database.transaction(): + query = notes.insert().values(text="example1", completed=True) + await database.execute(query) + + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + finally: + query = notes.delete() + await database.execute(query) + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_connect_and_disconnect(database_url):