Skip to content

Commit

Permalink
support dialect+driver for default drivers (closes #395) (#396)
Browse files Browse the repository at this point in the history
Co-authored-by: Amin Alaee <mohammadamin.alaee@gmail.com>
  • Loading branch information
mhadam and aminalaee committed Sep 25, 2021
1 parent 612857d commit 3525ca5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ jobs:
run: "scripts/install"
- name: "Run tests"
env:
TEST_DATABASE_URLS: "sqlite:///testsuite, mysql://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:password@127.0.0.1:5432/testsuite"
TEST_DATABASE_URLS: "sqlite:///testsuite, sqlite+aiosqlite:///testsuite, mysql://username:password@localhost:3306/testsuite, mysql+aiomysql://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:password@127.0.0.1:5432/testsuite, postgresql+asyncpg://username:password@localhost:5432/testsuite"
run: "scripts/test"
8 changes: 7 additions & 1 deletion databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(

self._force_rollback = force_rollback

backend_str = self.SUPPORTED_BACKENDS[self.url.scheme]
backend_str = self._get_backend()
backend_cls = import_from_string(backend_str)
assert issubclass(backend_cls, DatabaseBackend)
self._backend = backend_cls(self.url, **self.options)
Expand Down Expand Up @@ -220,6 +220,12 @@ def force_rollback(self) -> typing.Iterator[None]:
finally:
self._force_rollback = initial

def _get_backend(self) -> str:
try:
return self.SUPPORTED_BACKENDS[self.url.scheme]
except KeyError:
return self.SUPPORTED_BACKENDS[self.url.dialect]


class Connection:
def __init__(self, backend: DatabaseBackend) -> None:
Expand Down
59 changes: 41 additions & 18 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import os
import re
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch

import pytest
import sqlalchemy
Expand Down Expand Up @@ -78,14 +78,18 @@ def process_result_value(self, value, dialect):
)


@pytest.fixture(autouse=True, scope="module")
@pytest.fixture(autouse=True, scope="function")
def create_test_database():
# Create test databases with tables creation
for url in DATABASE_URLS:
database_url = DatabaseURL(url)
if database_url.scheme == "mysql":
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
url = str(database_url.replace(driver="pymysql"))
elif database_url.scheme == "postgresql+aiopg":
elif database_url.scheme in [
"postgresql+aiopg",
"sqlite+aiosqlite",
"postgresql+asyncpg",
]:
url = str(database_url.replace(driver=None))
engine = sqlalchemy.create_engine(url)
metadata.create_all(engine)
Expand All @@ -96,9 +100,13 @@ def create_test_database():
# Drop test databases
for url in DATABASE_URLS:
database_url = DatabaseURL(url)
if database_url.scheme == "mysql":
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
url = str(database_url.replace(driver="pymysql"))
elif database_url.scheme == "postgresql+aiopg":
elif database_url.scheme in [
"postgresql+aiopg",
"sqlite+aiosqlite",
"postgresql+asyncpg",
]:
url = str(database_url.replace(driver=None))
engine = sqlalchemy.create_engine(url)
metadata.drop_all(engine)
Expand Down Expand Up @@ -478,9 +486,12 @@ async def test_transaction_commit_serializable(database_url):

database_url = DatabaseURL(database_url)

if database_url.scheme != "postgresql":
if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]:
pytest.skip("Test (currently) only supports asyncpg")

if database_url.scheme == "postgresql+asyncpg":
database_url = database_url.replace(driver=None)

def insert_independently():
engine = sqlalchemy.create_engine(str(database_url))
conn = engine.connect()
Expand Down Expand Up @@ -844,26 +855,34 @@ async def test_queries_with_expose_backend_connection(database_url):
raw_connection = connection.raw_connection

# Insert query
if database.url.scheme in ["mysql", "postgresql+aiopg"]:
if database.url.scheme in [
"mysql",
"mysql+aiomysql",
"postgresql+aiopg",
]:
insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)"
else:
insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)"

# execute()
values = ("example1", True)

if database.url.scheme in ["mysql", "postgresql+aiopg"]:
if database.url.scheme in [
"mysql",
"mysql+aiomysql",
"postgresql+aiopg",
]:
cursor = await raw_connection.cursor()
await cursor.execute(insert_query, values)
elif database.url.scheme == "postgresql":
elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
await raw_connection.execute(insert_query, *values)
elif database.url.scheme == "sqlite":
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
await raw_connection.execute(insert_query, values)

# execute_many()
values = [("example2", False), ("example3", True)]

if database.url.scheme == "mysql":
if database.url.scheme in ["mysql", "mysql+aiomysql"]:
cursor = await raw_connection.cursor()
await cursor.executemany(insert_query, values)
elif database.url.scheme == "postgresql+aiopg":
Expand All @@ -878,13 +897,17 @@ async def test_queries_with_expose_backend_connection(database_url):
select_query = "SELECT notes.id, notes.text, notes.completed FROM notes"

# fetch_all()
if database.url.scheme in ["mysql", "postgresql+aiopg"]:
if database.url.scheme in [
"mysql",
"mysql+aiomysql",
"postgresql+aiopg",
]:
cursor = await raw_connection.cursor()
await cursor.execute(select_query)
results = await cursor.fetchall()
elif database.url.scheme == "postgresql":
elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
results = await raw_connection.fetch(select_query)
elif database.url.scheme == "sqlite":
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
results = await raw_connection.execute_fetchall(select_query)

assert len(results) == 3
Expand All @@ -897,7 +920,7 @@ async def test_queries_with_expose_backend_connection(database_url):
assert results[2][2] == True

# fetch_one()
if database.url.scheme == "postgresql":
if database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
result = await raw_connection.fetchrow(select_query)
else:
cursor = await raw_connection.cursor()
Expand Down Expand Up @@ -1065,8 +1088,8 @@ async def test_posgres_interface(database_url):
"""
database_url = DatabaseURL(database_url)

if database_url.scheme != "postgresql":
pytest.skip("Test is only for postgresql")
if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]:
pytest.skip("Test is only for asyncpg")

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
Expand Down
16 changes: 12 additions & 4 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ def create_test_database():
# Create test databases
for url in DATABASE_URLS:
database_url = DatabaseURL(url)
if database_url.scheme == "mysql":
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
url = str(database_url.replace(driver="pymysql"))
elif database_url.scheme == "postgresql+aiopg":
elif database_url.scheme in [
"postgresql+aiopg",
"sqlite+aiosqlite",
"postgresql+asyncpg",
]:
url = str(database_url.replace(driver=None))
engine = sqlalchemy.create_engine(url)
metadata.create_all(engine)
Expand All @@ -41,9 +45,13 @@ def create_test_database():
# Drop test databases
for url in DATABASE_URLS:
database_url = DatabaseURL(url)
if database_url.scheme == "mysql":
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
url = str(database_url.replace(driver="pymysql"))
elif database_url.scheme == "postgresql+aiopg":
elif database_url.scheme in [
"postgresql+aiopg",
"sqlite+aiosqlite",
"postgresql+asyncpg",
]:
url = str(database_url.replace(driver=None))
engine = sqlalchemy.create_engine(url)
metadata.drop_all(engine)
Expand Down

0 comments on commit 3525ca5

Please sign in to comment.