diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 1f0a2eee90beae..394fceb69b788a 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -10,8 +10,7 @@ connection - Tests for the different SQL flavors (flavor specific type conversions) - Tests for the sqlalchemy mode: `_TestSQLAlchemy` is the base class with - common methods, `_TestSQLAlchemyConn` tests the API with a SQLAlchemy - Connection object. The different tested flavors (sqlite3, MySQL, + common methods. The different tested flavors (sqlite3, MySQL, PostgreSQL) derive from the base class - Tests for the fallback mode (`TestSQLiteFallback`) @@ -664,38 +663,47 @@ class MixInBase: def teardown_method(self): # if setup fails, there may not be a connection to close. if hasattr(self, "conn"): - for tbl in self._get_all_tables(): - self.drop_table(tbl) - self._close_conn() + self.conn.close() + # use a fresh connection to ensure we can drop all tables. + try: + conn = self.connect() + except (sqlalchemy.exc.OperationalError, sqlite3.OperationalError): + pass + else: + with conn: + for tbl in self._get_all_tables(conn): + self.drop_table(tbl, conn) class SQLiteMixIn(MixInBase): - def drop_table(self, table_name): - self.conn.execute( - f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}" - ) - self.conn.commit() + def connect(self): + return sqlite3.connect(":memory:") - def _get_all_tables(self): - c = self.conn.execute("SELECT name FROM sqlite_master WHERE type='table'") - return [table[0] for table in c.fetchall()] + def drop_table(self, table_name, conn): + conn.execute(f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}") + conn.commit() - def _close_conn(self): - self.conn.close() + def _get_all_tables(self, conn): + c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + return [table[0] for table in c.fetchall()] class SQLAlchemyMixIn(MixInBase): - def drop_table(self, table_name): - sql.SQLDatabase(self.conn).drop_table(table_name) + @classmethod + def teardown_class(cls): + cls.engine.dispose() - def _get_all_tables(self): - from sqlalchemy import inspect + def connect(self): + return self.engine.connect() - return inspect(self.conn).get_table_names() + def drop_table(self, table_name, conn): + with conn.begin(): + sql.SQLDatabase(conn).drop_table(table_name) + + def _get_all_tables(self, conn): + from sqlalchemy import inspect - def _close_conn(self): - # https://docs.sqlalchemy.org/en/14/core/connections.html#engine-disposal - self.conn.dispose() + return inspect(conn).get_table_names() class PandasSQLTest: @@ -704,20 +712,14 @@ class PandasSQLTest: """ - @pytest.fixture def load_iris_data(self, iris_path): - if not hasattr(self, "conn"): - self.setup_connect() - self.drop_table("iris") + self.drop_table("iris", self.conn) if isinstance(self.conn, sqlite3.Connection): create_and_load_iris_sqlite3(self.conn, iris_path) else: create_and_load_iris(self.conn, iris_path, self.flavor) - @pytest.fixture def load_types_data(self, types_data): - if not hasattr(self, "conn"): - self.setup_connect() if self.flavor != "postgresql": for entry in types_data: entry.pop("DateColWithTz") @@ -745,13 +747,13 @@ def _read_sql_iris_no_parameter_with_percent(self): check_iris_frame(iris_frame) def _to_sql_empty(self, test_frame1): - self.drop_table("test_frame1") + self.drop_table("test_frame1", self.conn) assert self.pandasSQL.to_sql(test_frame1.iloc[:0], "test_frame1") == 0 def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs): """`to_sql` with the `engine` param""" # mostly copied from this class's `_to_sql()` method - self.drop_table("test_frame1") + self.drop_table("test_frame1", self.conn) assert ( self.pandasSQL.to_sql( @@ -766,10 +768,10 @@ def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs): assert num_rows == num_entries # Nuke table - self.drop_table("test_frame1") + self.drop_table("test_frame1", self.conn) def _roundtrip(self, test_frame1): - self.drop_table("test_frame_roundtrip") + self.drop_table("test_frame_roundtrip", self.conn) assert self.pandasSQL.to_sql(test_frame1, "test_frame_roundtrip") == 4 result = self.pandasSQL.read_query("SELECT * FROM test_frame_roundtrip") @@ -855,11 +857,13 @@ class _TestSQLApi(PandasSQLTest): flavor = "sqlite" mode: str - def setup_connect(self): - self.conn = self.connect() - @pytest.fixture(autouse=True) - def setup_method(self, load_iris_data, load_types_data): + def setup_method(self, iris_path, types_data): + self.conn = self.connect() + if not isinstance(self.conn, sqlite3.Connection): + self.conn.begin() + self.load_iris_data(iris_path) + self.load_types_data(types_data) self.load_test_data_and_sql() def load_test_data_and_sql(self): @@ -1287,7 +1291,8 @@ def test_escaped_table_name(self): tm.assert_frame_equal(res, df) -class _TestSQLApiEngine(SQLAlchemyMixIn, _TestSQLApi): +@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="SQLAlchemy not installed") +class TestSQLApi(SQLAlchemyMixIn, _TestSQLApi): """ Test the public API as it would be used directly @@ -1299,8 +1304,9 @@ class _TestSQLApiEngine(SQLAlchemyMixIn, _TestSQLApi): flavor = "sqlite" mode = "sqlalchemy" - def connect(self): - return sqlalchemy.create_engine("sqlite:///:memory:") + @classmethod + def setup_class(cls): + cls.engine = sqlalchemy.create_engine("sqlite:///:memory:") def test_read_table_columns(self, test_frame1): # test columns argument in read_table @@ -1488,34 +1494,6 @@ def test_column_with_percentage(self): tm.assert_frame_equal(res, df) -class _EngineToConnMixin: - """ - A mixin that causes setup_connect to create a conn rather than an engine. - """ - - @pytest.fixture(autouse=True) - def setup_method(self, load_iris_data, load_types_data): - super().load_test_data_and_sql() - engine = self.conn - conn = engine.connect() - self.__tx = conn.begin() - self.pandasSQL = sql.SQLDatabase(conn) - self.__engine = engine - self.conn = conn - - yield - - self.__tx.rollback() - self.conn.close() - self.conn = self.__engine - self.pandasSQL = sql.SQLDatabase(self.__engine) - - -@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="SQLAlchemy not installed") -class TestSQLApiConn(_EngineToConnMixin, _TestSQLApiEngine): - pass - - class TestSQLiteFallbackApi(SQLiteMixIn, _TestSQLApi): """ Test the public sqlite connection fallback API @@ -1607,6 +1585,7 @@ def test_sqlite_type_mapping(self): # -- Database flavor specific tests +@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="SQLAlchemy not installed") class _TestSQLAlchemy(SQLAlchemyMixIn, PandasSQLTest): """ Base class for testing the sqlalchemy backend. @@ -1619,43 +1598,29 @@ class _TestSQLAlchemy(SQLAlchemyMixIn, PandasSQLTest): flavor: str @classmethod - @pytest.fixture(autouse=True, scope="class") def setup_class(cls): - cls.setup_import() cls.setup_driver() - conn = cls.conn = cls.connect() - conn.connect() - - def load_test_data_and_sql(self): - pass + cls.setup_engine() @pytest.fixture(autouse=True) - def setup_method(self, load_iris_data, load_types_data): - pass - - @classmethod - def setup_import(cls): - # Skip this test if SQLAlchemy not available - if not SQLALCHEMY_INSTALLED: - pytest.skip("SQLAlchemy not installed") + def setup_method(self, iris_path, types_data): + try: + self.conn = self.engine.connect() + self.conn.begin() + self.pandasSQL = sql.SQLDatabase(self.conn) + except sqlalchemy.exc.OperationalError: + pytest.skip(f"Can't connect to {self.flavor} server") + self.load_iris_data(iris_path) + self.load_types_data(types_data) @classmethod def setup_driver(cls): raise NotImplementedError() @classmethod - def connect(cls): + def setup_engine(cls): raise NotImplementedError() - def setup_connect(self): - try: - self.conn = self.connect() - self.pandasSQL = sql.SQLDatabase(self.conn) - # to test if connection can be made: - self.conn.connect() - except sqlalchemy.exc.OperationalError: - pytest.skip(f"Can't connect to {self.flavor} server") - def test_read_sql_parameter(self): self._read_sql_iris_parameter() @@ -2041,6 +2006,7 @@ def _get_index_columns(self, tbl_name): def test_to_sql_save_index(self): self._to_sql_save_index() + @pytest.mark.xfail(reason="Nested transactions rollbacks don't work with Pandas") def test_transactions(self): self._transaction_test() @@ -2055,7 +2021,7 @@ def test_get_schema_create_table(self, test_frame3): create_sql = sql.get_schema(test_frame3, tbl, con=self.conn) blank_test_df = test_frame3.iloc[:0] - self.drop_table(tbl) + self.drop_table(tbl, self.conn) create_sql = text(create_sql) if isinstance(self.conn, Engine): with self.conn.connect() as conn: @@ -2065,7 +2031,7 @@ def test_get_schema_create_table(self, test_frame3): self.conn.execute(create_sql) returned_df = sql.read_sql_table(tbl, self.conn) tm.assert_frame_equal(returned_df, blank_test_df, check_index_type=False) - self.drop_table(tbl) + self.drop_table(tbl, self.conn) def test_dtype(self): from sqlalchemy import ( @@ -2301,13 +2267,7 @@ def test_get_engine_auto_error_message(self): # TODO(GH#36893) fill this in when we add more engines -class _TestSQLAlchemyConn(_EngineToConnMixin, _TestSQLAlchemy): - @pytest.mark.xfail(reason="Nested transactions rollbacks don't work with Pandas") - def test_transactions(self): - super().test_transactions() - - -class _TestSQLiteAlchemy: +class TestSQLiteAlchemy(_TestSQLAlchemy): """ Test the sqlalchemy backend against an in-memory sqlite database. @@ -2316,8 +2276,8 @@ class _TestSQLiteAlchemy: flavor = "sqlite" @classmethod - def connect(cls): - return sqlalchemy.create_engine("sqlite:///:memory:") + def setup_engine(cls): + cls.engine = sqlalchemy.create_engine("sqlite:///:memory:") @classmethod def setup_driver(cls): @@ -2390,7 +2350,8 @@ class Test(BaseModel): assert list(df.columns) == ["id", "string_column"] -class _TestMySQLAlchemy: +@pytest.mark.db +class TestMySQLAlchemy(_TestSQLAlchemy): """ Test the sqlalchemy backend against an MySQL database. @@ -2400,8 +2361,8 @@ class _TestMySQLAlchemy: port = 3306 @classmethod - def connect(cls): - return sqlalchemy.create_engine( + def setup_engine(cls): + cls.engine = sqlalchemy.create_engine( f"mysql+{cls.driver}://root@localhost:{cls.port}/pandas", connect_args=cls.connect_args, ) @@ -2416,7 +2377,8 @@ def test_default_type_conversion(self): pass -class _TestPostgreSQLAlchemy: +@pytest.mark.db +class TestPostgreSQLAlchemy(_TestSQLAlchemy): """ Test the sqlalchemy backend against an PostgreSQL database. @@ -2426,8 +2388,8 @@ class _TestPostgreSQLAlchemy: port = 5432 @classmethod - def connect(cls): - return sqlalchemy.create_engine( + def setup_engine(cls): + cls.engine = sqlalchemy.create_engine( f"postgresql+{cls.driver}://postgres:postgres@localhost:{cls.port}/pandas" ) @@ -2525,20 +2487,6 @@ def test_schema_support(self): tm.assert_frame_equal(res1, res2) -@pytest.mark.db -class TestMySQLAlchemyConn(_TestMySQLAlchemy, _TestSQLAlchemyConn): - pass - - -@pytest.mark.db -class TestPostgreSQLAlchemyConn(_TestPostgreSQLAlchemy, _TestSQLAlchemyConn): - pass - - -class TestSQLiteAlchemyConn(_TestSQLiteAlchemy, _TestSQLAlchemyConn): - pass - - # ----------------------------------------------------------------------------- # -- Test Sqlite / MySQL fallback @@ -2551,15 +2499,11 @@ class TestSQLiteFallback(SQLiteMixIn, PandasSQLTest): flavor = "sqlite" - @classmethod - def connect(cls): - return sqlite3.connect(":memory:") - - def setup_connect(self): - self.conn = self.connect() - @pytest.fixture(autouse=True) - def setup_method(self, load_iris_data, load_types_data): + def setup_method(self, iris_path, types_data): + self.conn = self.connect() + self.load_iris_data(iris_path) + self.load_types_data(types_data) self.pandasSQL = sql.SQLiteDatabase(self.conn) def test_read_sql_parameter(self):