diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 700a2490..da12a51e 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -47,5 +47,5 @@ jobs: run: "scripts/install" - name: "Run tests" env: - TEST_DATABASE_URLS: "sqlite:///testsuite, mysql://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:password@127.0.0.1:5432/testsuite" + TEST_DATABASE_URLS: "sqlite:///testsuite, sqlite+aiosqlite:///testsuite, mysql://username:password@localhost:3306/testsuite, mysql+aiomysql://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:password@127.0.0.1:5432/testsuite, postgresql+asyncpg://username:password@localhost:5432/testsuite" run: "scripts/test" diff --git a/databases/core.py b/databases/core.py index 727802d4..95040547 100644 --- a/databases/core.py +++ b/databases/core.py @@ -62,7 +62,7 @@ def __init__( self._force_rollback = force_rollback - backend_str = self.SUPPORTED_BACKENDS[self.url.scheme] + backend_str = self._get_backend() backend_cls = import_from_string(backend_str) assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) @@ -220,6 +220,12 @@ def force_rollback(self) -> typing.Iterator[None]: finally: self._force_rollback = initial + def _get_backend(self) -> str: + try: + return self.SUPPORTED_BACKENDS[self.url.scheme] + except KeyError: + return self.SUPPORTED_BACKENDS[self.url.dialect] + class Connection: def __init__(self, backend: DatabaseBackend) -> None: diff --git a/tests/test_databases.py b/tests/test_databases.py index 445a2453..8fde4387 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -4,7 +4,7 @@ import functools import os import re -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import pytest import sqlalchemy @@ -78,14 +78,18 @@ def process_result_value(self, value, dialect): ) -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(autouse=True, scope="function") def create_test_database(): # Create test databases with tables creation for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme in ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in [ + "postgresql+aiopg", + "sqlite+aiosqlite", + "postgresql+asyncpg", + ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -96,9 +100,13 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme in ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in [ + "postgresql+aiopg", + "sqlite+aiosqlite", + "postgresql+asyncpg", + ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) @@ -478,9 +486,12 @@ async def test_transaction_commit_serializable(database_url): database_url = DatabaseURL(database_url) - if database_url.scheme != "postgresql": + if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]: pytest.skip("Test (currently) only supports asyncpg") + if database_url.scheme == "postgresql+asyncpg": + database_url = database_url.replace(driver=None) + def insert_independently(): engine = sqlalchemy.create_engine(str(database_url)) conn = engine.connect() @@ -844,7 +855,11 @@ async def test_queries_with_expose_backend_connection(database_url): raw_connection = connection.raw_connection # Insert query - if database.url.scheme in ["mysql", "postgresql+aiopg"]: + if database.url.scheme in [ + "mysql", + "mysql+aiomysql", + "postgresql+aiopg", + ]: insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)" else: insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)" @@ -852,18 +867,22 @@ async def test_queries_with_expose_backend_connection(database_url): # execute() values = ("example1", True) - if database.url.scheme in ["mysql", "postgresql+aiopg"]: + if database.url.scheme in [ + "mysql", + "mysql+aiomysql", + "postgresql+aiopg", + ]: cursor = await raw_connection.cursor() await cursor.execute(insert_query, values) - elif database.url.scheme == "postgresql": + elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]: await raw_connection.execute(insert_query, *values) - elif database.url.scheme == "sqlite": + elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: await raw_connection.execute(insert_query, values) # execute_many() values = [("example2", False), ("example3", True)] - if database.url.scheme == "mysql": + if database.url.scheme in ["mysql", "mysql+aiomysql"]: cursor = await raw_connection.cursor() await cursor.executemany(insert_query, values) elif database.url.scheme == "postgresql+aiopg": @@ -878,13 +897,17 @@ async def test_queries_with_expose_backend_connection(database_url): select_query = "SELECT notes.id, notes.text, notes.completed FROM notes" # fetch_all() - if database.url.scheme in ["mysql", "postgresql+aiopg"]: + if database.url.scheme in [ + "mysql", + "mysql+aiomysql", + "postgresql+aiopg", + ]: cursor = await raw_connection.cursor() await cursor.execute(select_query) results = await cursor.fetchall() - elif database.url.scheme == "postgresql": + elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]: results = await raw_connection.fetch(select_query) - elif database.url.scheme == "sqlite": + elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: results = await raw_connection.execute_fetchall(select_query) assert len(results) == 3 @@ -897,7 +920,7 @@ async def test_queries_with_expose_backend_connection(database_url): assert results[2][2] == True # fetch_one() - if database.url.scheme == "postgresql": + if database.url.scheme in ["postgresql", "postgresql+asyncpg"]: result = await raw_connection.fetchrow(select_query) else: cursor = await raw_connection.cursor() @@ -1065,8 +1088,8 @@ async def test_posgres_interface(database_url): """ database_url = DatabaseURL(database_url) - if database_url.scheme != "postgresql": - pytest.skip("Test is only for postgresql") + if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]: + pytest.skip("Test is only for asyncpg") async with Database(database_url) as database: async with database.transaction(force_rollback=True): diff --git a/tests/test_integration.py b/tests/test_integration.py index c0cef2db..f53471f6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -28,9 +28,13 @@ def create_test_database(): # Create test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme in ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in [ + "postgresql+aiopg", + "sqlite+aiosqlite", + "postgresql+asyncpg", + ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -41,9 +45,13 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme in ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in [ + "postgresql+aiopg", + "sqlite+aiosqlite", + "postgresql+asyncpg", + ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine)