diff --git a/tests/test_databases.py b/tests/test_databases.py index f6ef67f9..ad58cbea 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -83,11 +83,11 @@ def create_test_database(): # Create test databases with tables creation for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme == ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in ["postgresql+aiopg", "postgresql+asyncpg"]: url = str(database_url.replace(driver=None)) - elif database_url.scheme == "postgresql+asyncpg": + elif database_url.scheme in ["sqlite", "sqlite+aiosqlite"]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -98,11 +98,11 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme == ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in ["postgresql+aiopg", "postgresql+asyncpg"]: url = str(database_url.replace(driver=None)) - elif database_url.scheme == "postgresql+asyncpg": + elif database_url.scheme in ["sqlite", "sqlite+aiosqlite"]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) diff --git a/tests/test_integration.py b/tests/test_integration.py index c0cef2db..032c3842 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -28,9 +28,11 @@ def create_test_database(): # Create test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme == ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in ["postgresql+aiopg", "postgresql+asyncpg"]: + url = str(database_url.replace(driver=None)) + elif database_url.scheme in ["sqlite", "sqlite+aiosqlite"]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -41,9 +43,11 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.scheme == "mysql": + if database_url.scheme == ["mysql", "mysql+aiomysql"]: url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": + elif database_url.scheme in ["postgresql+aiopg", "postgresql+asyncpg"]: + url = str(database_url.replace(driver=None)) + elif database_url.scheme in ["sqlite", "sqlite+aiosqlite"]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine)