Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ python:
- "3.7"

env:
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database"
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db"

services:
- postgresql
Expand Down
82 changes: 38 additions & 44 deletions databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import aiomysql
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.sql import ClauseElement
from sqlalchemy.types import TypeEngine

Expand All @@ -14,12 +15,10 @@

logger = logging.getLogger("databases")

_result_processors = {} # type: dict


class MySQLBackend(DatabaseBackend):
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
self._database_url = DatabaseURL(database_url)
def __init__(self, database_url: DatabaseURL) -> None:
self._database_url = database_url
self._dialect = pymysql.dialect(paramstyle="pyformat")
self._pool = None

Expand All @@ -45,28 +44,9 @@ def connection(self) -> "MySQLConnection":
return MySQLConnection(self._pool, self._dialect)


class Record:
def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None:
self._row = row
self._result_columns = result_columns
self._dialect = dialect
self._column_map = {
column_name: (idx, datatype)
for idx, (column_name, _, _, datatype) in enumerate(self._result_columns)
}

def __getitem__(self, key: str) -> typing.Any:
idx, datatype = self._column_map[key]
raw = self._row[idx]
try:
processor = _result_processors[datatype]
except KeyError:
processor = datatype.result_processor(self._dialect, None)
_result_processors[datatype] = processor

if processor is not None:
return processor(raw)
return raw
class CompilationContext:
def __init__(self, context: ExecutionContext):
self.context = context


class MySQLConnection(ConnectionBackend):
Expand All @@ -84,33 +64,38 @@ async def release(self) -> None:
await self._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.Any:
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
query, args, context = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
rows = await cursor.fetchall()
return [Record(row, result_columns, self._dialect) for row in rows]
metadata = ResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
for row in rows
]
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Any:
async def fetch_one(self, query: ClauseElement) -> RowProxy:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
query, args, context = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
row = await cursor.fetchone()
return Record(row, result_columns, self._dialect)
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
finally:
await cursor.close()

async def execute(self, query: ClauseElement, values: dict = None) -> None:
assert self._connection is not None, "Connection is not acquired"
if values is not None:
query = query.values(values)
query, args, result_columns = self._compile(query)
query, args, context = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
Expand All @@ -123,7 +108,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
try:
for item in values:
single_query = query.values(item)
single_query, args, result_columns = self._compile(single_query)
single_query, args, context = self._compile(single_query)
await cursor.execute(single_query, args)
finally:
await cursor.close()
Expand All @@ -132,26 +117,38 @@ async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
query, args, context = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
metadata = ResultMetaData(context, cursor.description)
async for row in cursor:
yield Record(row, result_columns, self._dialect)
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
finally:
await cursor.close()

def transaction(self) -> TransactionBackend:
return MySQLTransaction(self)

def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
def _compile(
self, query: ClauseElement
) -> typing.Tuple[str, dict, CompilationContext]:
compiled = query.compile(dialect=self._dialect)
args = compiled.construct_params()
logger.debug(compiled.string, args)
for key, val in args.items():
if key in compiled._bind_processors:
args[key] = compiled._bind_processors[key](val)
return compiled.string, args, compiled._result_columns

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect
execution_context.result_column_struct = (
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
)

logger.debug(compiled.string, args)
return compiled.string, args, CompilationContext(execution_context)


class MySQLTransaction(TransactionBackend):
Expand All @@ -176,10 +173,7 @@ async def start(self, is_root: bool) -> None:

async def commit(self) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
if self._is_root: # pragma: no cover
# In test cases the root transaction is never committed,
# since we *always* wrap the test case up in a transaction
# and rollback to a clean state at the end.
if self._is_root:
await self._connection._connection.commit()
else:
cursor = await self._connection._connection.cursor()
Expand Down
4 changes: 2 additions & 2 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@


class PostgresBackend(DatabaseBackend):
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
self._database_url = DatabaseURL(database_url)
def __init__(self, database_url: DatabaseURL) -> None:
self._database_url = database_url
self._dialect = self._get_dialect()
self._pool = None

Expand Down
194 changes: 194 additions & 0 deletions databases/backends/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import logging
import typing
import uuid

import aiosqlite
from sqlalchemy.dialects.sqlite import pysqlite
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.sql import ClauseElement
from sqlalchemy.types import TypeEngine

from databases.core import DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")


class SQLiteBackend(DatabaseBackend):
def __init__(self, database_url: DatabaseURL) -> None:
self._database_url = database_url
self._dialect = pysqlite.dialect(paramstyle="qmark")
self._pool = SQLitePool(database_url)

async def connect(self) -> None:
pass
# assert self._pool is None, "DatabaseBackend is already running"
# self._pool = await aiomysql.create_pool(
# host=self._database_url.hostname,
# port=self._database_url.port or 3306,
# user=self._database_url.username or getpass.getuser(),
# password=self._database_url.password,
# db=self._database_url.database,
# autocommit=True,
# )

async def disconnect(self) -> None:
pass
# assert self._pool is not None, "DatabaseBackend is not running"
# self._pool.close()
# await self._pool.wait_closed()
# self._pool = None

def connection(self) -> "SQLiteConnection":
return SQLiteConnection(self._pool, self._dialect)


class SQLitePool:
def __init__(self, url: DatabaseURL) -> None:
self._url = url

async def acquire(self) -> aiosqlite.Connection:
connection = aiosqlite.connect(
database=self._url.database, isolation_level=None
)
await connection.__aenter__()
return connection

async def release(self, connection: aiosqlite.Connection) -> None:
await connection.__aexit__(None, None, None)


class CompilationContext:
def __init__(self, context: ExecutionContext):
self.context = context


class SQLiteConnection(ConnectionBackend):
def __init__(self, pool: SQLitePool, dialect: Dialect):
self._pool = pool
self._dialect = dialect
self._connection = None

async def acquire(self) -> None:
assert self._connection is None, "Connection is already acquired"
self._connection = await self._pool.acquire()

async def release(self) -> None:
assert self._connection is not None, "Connection is not acquired"
await self._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
assert self._connection is not None, "Connection is not acquired"
query, args, context = self._compile(query)

async with self._connection.execute(query, args) as cursor:
rows = await cursor.fetchall()
metadata = ResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
for row in rows
]

async def fetch_one(self, query: ClauseElement) -> RowProxy:
assert self._connection is not None, "Connection is not acquired"
query, args, context = self._compile(query)

async with self._connection.execute(query, args) as cursor:
row = await cursor.fetchone()
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)

async def execute(self, query: ClauseElement, values: dict = None) -> None:
assert self._connection is not None, "Connection is not acquired"
if values is not None:
query = query.values(values)
query, args, context = self._compile(query)
cursor = await self._connection.execute(query, args)
await cursor.close()

async def execute_many(self, query: ClauseElement, values: list) -> None:
assert self._connection is not None, "Connection is not acquired"
for value in values:
await self.execute(query, value)

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query, args, context = self._compile(query)
cursor = await self._connection.cursor()
async with self._connection.execute(query, args) as cursor:
metadata = ResultMetaData(context, cursor.description)
async for row in cursor:
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)

def transaction(self) -> TransactionBackend:
return SQLiteTransaction(self)

def _compile(
self, query: ClauseElement
) -> typing.Tuple[str, list, CompilationContext]:
compiled = query.compile(dialect=self._dialect)
args = []
for key, raw_val in compiled.construct_params().items():
if key in compiled._bind_processors:
val = compiled._bind_processors[key](raw_val)
else:
val = raw_val
args.append(val)

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect
execution_context.result_column_struct = (
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
)

logger.debug(compiled.string, args)
return compiled.string, args, CompilationContext(execution_context)


class SQLiteTransaction(TransactionBackend):
def __init__(self, connection: SQLiteConnection):
self._connection = connection
self._is_root = False
self._savepoint_name = ""

async def start(self, is_root: bool) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
if self._is_root:
cursor = await self._connection._connection.execute("BEGIN")
await cursor.close()
else:
id = str(uuid.uuid4()).replace("-", "_")
self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
cursor = await self._connection._connection.execute(
f"SAVEPOINT {self._savepoint_name}"
)
await cursor.close()

async def commit(self) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
if self._is_root:
cursor = await self._connection._connection.execute("COMMIT")
await cursor.close()
else:
cursor = await self._connection._connection.execute(
f"RELEASE SAVEPOINT {self._savepoint_name}"
)
await cursor.close()

async def rollback(self) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
if self._is_root:
cursor = await self._connection._connection.execute("ROLLBACK")
await cursor.close()
else:
cursor = await self._connection._connection.execute(
f"ROLLBACK TO SAVEPOINT {self._savepoint_name}"
)
await cursor.close()
Loading