diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py index 498bbbaa..779585be 100644 --- a/sqlalchemy_continuum/manager.py +++ b/sqlalchemy_continuum/manager.py @@ -145,6 +145,8 @@ def reset(self): # as UnitOfWork objects. self.units_of_work = {} + self.session_connection_map = {} + self.metadata = None def create_transaction_model(self): @@ -305,6 +307,8 @@ def unit_of_work(self, session): :param session: SQLAlchemy session object """ conn = session.connection() + if conn not in self.session_connection_map.values(): + self.session_connection_map[session] = conn if conn in self.units_of_work: return self.units_of_work[conn] @@ -352,10 +356,11 @@ def clear(self, session): """ if session.transaction.nested: return - conn = session.bind + conn = self.session_connection_map.pop(session, None) if conn in self.units_of_work: uow = self.units_of_work[conn] uow.reset(session) + del self.units_of_work[conn] def append_association_operation(self, conn, table_name, params, op): """ diff --git a/tests/__init__.py b/tests/__init__.py index fd8e6aef..310e9bb6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -121,6 +121,10 @@ def drop_tables(self): self.Model.metadata.drop_all(self.connection) def teardown_method(self, method): + self.session.rollback() + uow_leaks = versioning_manager.units_of_work + session_map_leaks = versioning_manager.session_connection_map + remove_versioning() QueryPool.queries = [] versioning_manager.reset() @@ -131,6 +135,9 @@ def teardown_method(self, method): self.engine.dispose() self.connection.close() + assert not uow_leaks + assert not session_map_leaks + def create_models(self): class Article(self.Model): __tablename__ = 'article' diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 2193b04c..f5780f89 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -24,6 +24,13 @@ def test_multiple_connections(self): article.versions[-1].transaction_id ) + def test_connection_binded_to_engine(self): + self.session2 = Session(bind=self.engine) + article = self.Article(name=u'Session1 article') + self.session2.add(article) + self.session2.commit() + assert article.versions[-1].transaction_id + def test_manual_transaction_creation(self): uow = versioning_manager.unit_of_work(self.session) transaction = uow.create_transaction(self.session)