# Towards an idempotent connection context manager

Note that a more general, reusable tool for managing nested contexts was attempted in the [Not entering/exiting contexts twice when nested](https://github.com/i2mint/lkj/issues/1) issue in `lkj`. That attempt hasn't succeeded for now, so the issue was closed.

Below are some more database-specific implementation notes:

Say we have a connection class (or function/method) and a DAO using it:


In [8]:
class Connection:
    def __init__(self, uri='DFLT_URI'):
        self.uri = uri

    def __enter__(self):
        print(f"Opening connection to {self.uri}")
        return self

    def __exit__(self, *exc_info):
        print(f"Closing connection to {self.uri}")

    open = __enter__
    close = __exit__

class DAO1:
    def __init__(self, uri='DFLT_URI'):
        self.connection = Connection(uri)

    def read(self, k):
        with self.connection:
            print(f"    Reading {k}")

dao1 = DAO1()
dao1.read('bob')
dao1.read('alice')




Opening connection to DFLT_URI
    Reading bob
Closing connection to DFLT_URI
Opening connection to DFLT_URI
    Reading alice
Closing connection to DFLT_URI


If we now make our DAO a context manager itself, reads made outside a `with` block behave exactly as before: the connection is opened and closed at every read.

In [7]:
class DAO2(DAO1):
    def __enter__(self):
        print(f"DAO entry...")
        self.connection.__enter__()
        return self

    def __exit__(self, *exc_info):
        print(f"... DAO exit")
        self.connection.__exit__(*exc_info)
        
dao2 = DAO2()
dao2.read('bob')
dao2.read('alice')

Opening connection to DFLT_URI
    Reading bob
Closing connection to DFLT_URI
Opening connection to DFLT_URI
    Reading alice
Closing connection to DFLT_URI


Additionally, when we use `dao2` as a context manager, the connection is opened twice in a row at the beginning and closed twice in a row at the end. That could be a problem if our connection object doesn't like being opened when it's already open, or closed when it's already closed.

In [6]:
with dao2:
    dao2.read('bob')
    dao2.read('alice')


DAO entry...
Opening connection to DFLT_URI
Opening connection to DFLT_URI
    Reading bob
Closing connection to DFLT_URI
Opening connection to DFLT_URI
    Reading alice
Closing connection to DFLT_URI
... DAO exit
Closing connection to DFLT_URI
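To make the risk concrete, here is a sketch with a hypothetical `StrictConnection` (not part of the original code) that refuses to be opened when already open or closed when already closed. Used through the same structure as `DAO2`, a plain `read` works, but the first `read` inside `with dao:` fails because the DAO already opened the connection:

```python
class StrictConnection:
    """Hypothetical connection that refuses to be opened twice or closed twice."""

    def __init__(self, uri='DFLT_URI'):
        self.uri = uri
        self.is_open = False

    def __enter__(self):
        if self.is_open:
            raise RuntimeError(f"Connection to {self.uri} is already open")
        self.is_open = True
        return self

    def __exit__(self, *exc_info):
        if not self.is_open:
            raise RuntimeError(f"Connection to {self.uri} is already closed")
        self.is_open = False


class StrictDAO:
    """Same structure as DAO2: each read opens the connection, and so does the DAO."""

    def __init__(self, uri='DFLT_URI'):
        self.connection = StrictConnection(uri)

    def read(self, k):
        with self.connection:
            return f"value of {k}"

    def __enter__(self):
        self.connection.__enter__()
        return self

    def __exit__(self, *exc_info):
        return self.connection.__exit__(*exc_info)


dao = StrictDAO()
print(dao.read('bob'))  # fine: nothing else has the connection open

try:
    with dao:
        dao.read('bob')  # the DAO already opened the connection, so this raises
except RuntimeError as e:
    print("Error:", e)
```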


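So, as it stands, we don't get much use out of our DAO's context manager.

What we'd like is for reads inside a `with dao:` block to reuse the connection the DAO opened, while reads outside such a block still open and close their own connection. A first, ad hoc way to get this: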

In [5]:
class DAO3:
    def __init__(self, uri='DFLT_URI'):
        self.connection = Connection(uri)
        self.connection_opened = False

    def read(self, k):
        if self.connection_opened:
            print(f"    Reading {k}")
        else:
            with self.connection:
                print(f"    Reading {k}")

    def __enter__(self):
        print(f"DAO entry...")
        self.connection.__enter__()
        self.connection_opened = True
        return self

    def __exit__(self, *exc_info):
        print(f"... DAO exit")
        self.connection_opened = False
        return self.connection.__exit__(*exc_info)

dao3 = DAO3()    
with dao3:
    dao3.read('bob')
    dao3.read('alice')

DAO entry...
Opening connection to DFLT_URI
    Reading bob
    Reading alice
... DAO exit
Closing connection to DFLT_URI


Good, we got what we wanted, but it's not very clean, nor is it reusable: we'd need the same logic in other methods too (write, delete, etc.).

How can I separate these concerns more cleanly (the DAO object and its methods, the "don't open/close twice" context-management concern, etc.)?

In [32]:
class Connection:
    def __init__(self, uri='DFLT_URI'):
        self.uri = uri
        self.is_open = False

    def open(self):
        if not self.is_open:
            print(f"Opening connection to {self.uri}")
            self.is_open = True
        return self

    def close(self):
        if self.is_open:
            print(f"Closing connection to {self.uri}")
            self.is_open = False

    # `with` passes exception info to __exit__, which `close` doesn't accept,
    # so wrap open/close instead of aliasing them
    def __enter__(self):
        return self.open()

    def __exit__(self, *exc_info):
        self.close()

class FlexibleConnectionManager:
    def __init__(self, connection):
        self.connection = connection
        self.owns_connection = False

    def __enter__(self):
        if not self.connection.is_open:
            self.connection.open()
            self.owns_connection = True
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.owns_connection and self.connection.is_open:
            self.connection.close()
            self.owns_connection = False

class DAO1:
    def __init__(self, uri='DFLT_URI'):
        self.connection = Connection(uri)

    def read(self, k):
        with FlexibleConnectionManager(self.connection):
            print(f"    Reading {k}")

    def __enter__(self):
        self.connection_manager = FlexibleConnectionManager(self.connection)
        self.connection_manager.__enter__()
        return self  # so that `with DAO1() as dao:` gives the DAO

    def __exit__(self, *exc_info):
        self.connection_manager.__exit__(*exc_info)


In [23]:
__import__('ipytest').autoconfig()  # pip install ipytest

In [36]:
%%ipytest

def test_dao_1(capsys):
    dao1 = DAO1()    
    dao1.read('bob')
    dao1.read('alice')
    print("")
    with dao1:
        dao1.read('bob')
        dao1.read('alice')

    # Use capsys to capture print output
    captured = capsys.readouterr()

    # Assert against the captured output
    assert captured.out.splitlines() == [
        'Opening connection to DFLT_URI',
        '    Reading bob',
        'Closing connection to DFLT_URI',
        'Opening connection to DFLT_URI',
        '    Reading alice',
        'Closing connection to DFLT_URI',
        '',
        'Opening connection to DFLT_URI',
        '    Reading bob',
        '    Reading alice',
        'Closing connection to DFLT_URI'
    ]

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.00s[0m[0m

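One case this version doesn't cover is nesting the DAO's own context. `DAO1.__enter__` stores its `FlexibleConnectionManager` on the instance, so an inner `with dao:` overwrites the outer block's manager. The inner manager doesn't own the connection, and the outer `__exit__` then calls that inner manager, so nobody closes the connection. A self-contained sketch reproducing this (the classes are trimmed copies of the ones above, without the prints):

```python
class Connection:
    """Trimmed copy of the `Connection` above: only tracks whether it's open."""

    def __init__(self, uri='DFLT_URI'):
        self.uri = uri
        self.is_open = False

    def open(self):
        self.is_open = True

    def close(self):
        self.is_open = False


class FlexibleConnectionManager:
    """Copy of the manager above: closes the connection only if it opened it."""

    def __init__(self, connection):
        self.connection = connection
        self.owns_connection = False

    def __enter__(self):
        if not self.connection.is_open:
            self.connection.open()
            self.owns_connection = True
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.owns_connection and self.connection.is_open:
            self.connection.close()
            self.owns_connection = False


class DAO1:
    def __init__(self, uri='DFLT_URI'):
        self.connection = Connection(uri)

    def __enter__(self):
        # Each entry replaces the previously stored manager
        self.connection_manager = FlexibleConnectionManager(self.connection)
        self.connection_manager.__enter__()
        return self

    def __exit__(self, *exc_info):
        self.connection_manager.__exit__(*exc_info)


dao = DAO1()
with dao:
    with dao:  # the inner block overwrites dao.connection_manager
        pass
print(dao.connection.is_open)  # True: the connection was left open
```

Keeping a stack of managers (push on `__enter__`, pop on `__exit__`), or a single ref-counted manager per DAO, would avoid this.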

## Reusable Connection Manager

Good. But our setup here depends on a particular Connection class whose instances record whether they're open or not. 

We'd like to recreate this behavior with any connection context manager, or even with a context manager "factory" (a function or class that returns a context manager instance).


In [25]:
class RefCountedConnectionManager:
    def __init__(self, get_connection):
        self.get_connection = get_connection
        self.connection = None
        self.entered_connection = None  # whatever the connection's __enter__ returned
        self.ref_count = 0

    def __enter__(self):
        if self.ref_count == 0:
            # Only the outermost entry creates and opens a connection
            self.connection = self.get_connection()
            self.entered_connection = self.connection.__enter__()
        self.ref_count += 1
        return self.entered_connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.ref_count -= 1
        if self.ref_count == 0:
            # Only the outermost exit closes the connection
            connection, self.connection = self.connection, None
            self.entered_connection = None
            return connection.__exit__(exc_type, exc_val, exc_tb)

class DAO:
    def __init__(self, get_connection):
        # A single manager per DAO: its ref count is what tells nested uses apart.
        # (Re-creating it on exit would reset the count and break nested `with dao:`.)
        self.connection_manager = RefCountedConnectionManager(get_connection)

    def read(self, k):
        with self.connection_manager:
            print(f"    Reading {k}")

    def __enter__(self):
        self.connection_manager.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.connection_manager.__exit__(exc_type, exc_val, exc_tb)


In [26]:
%%ipytest

class _TestConnection:
    def __init__(self, uri='DFLT_URI'):
        self.uri = uri

    def __enter__(self):
        print(f"Opening connection to {self.uri}")
        return self

    def __exit__(self, *exc_info):
        print(f"Closing connection to {self.uri}")

    open = __enter__
    close = __exit__

def test_dao(capsys):
    dao = DAO(_TestConnection)    
    with dao:
        dao.read('bob')
        dao.read('alice')

    # Use capsys to capture print output
    captured = capsys.readouterr()

    # Assert against the captured output
    assert captured.out.splitlines() == [
        'Opening connection to DFLT_URI',
        '    Reading bob',
        '    Reading alice',
        'Closing connection to DFLT_URI'
    ]

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.00s[0m[0m

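Taking a factory rather than a connection instance matters for single-use context managers. A context manager built with `contextlib.contextmanager` can only be entered once, so a fresh one has to be produced each time the ref count goes back to zero. The sketch below restates the ref-counting idea in self-contained form (the `open_connection` generator and the `log` list are hypothetical stand-ins for a real connection) and checks that the connection is opened once per outermost `with`, and still closed when a read raises:

```python
from contextlib import contextmanager

log = []  # records open/close events instead of printing them


@contextmanager
def open_connection(uri='DFLT_URI'):
    """Hypothetical single-use connection context manager."""
    log.append(f"open {uri}")
    try:
        yield uri  # stand-in for a real connection object
    finally:
        log.append(f"close {uri}")


class RefCountedConnectionManager:
    """Enter the underlying context on the first entry, exit it on the last exit."""

    def __init__(self, get_connection):
        self.get_connection = get_connection
        self._cm = None
        self._entered = None
        self.ref_count = 0

    def __enter__(self):
        if self.ref_count == 0:
            self._cm = self.get_connection()  # fresh (possibly single-use) context manager
            self._entered = self._cm.__enter__()
        self.ref_count += 1
        return self._entered

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.ref_count -= 1
        if self.ref_count == 0:
            cm, self._cm, self._entered = self._cm, None, None
            return cm.__exit__(exc_type, exc_val, exc_tb)
        return None  # inner exits never suppress exceptions


manager = RefCountedConnectionManager(open_connection)

with manager as conn:
    with manager:  # nested: reuses the open connection
        log.append(f"read via {conn}")
with manager:  # a new outermost block gets a new generator-based context manager
    log.append("read again")

try:
    with manager:
        raise ValueError("read failed")
except ValueError:
    pass  # the connection was still closed on the way out

print(log)
```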

# Scrap

In [18]:
def track_entry_context_manager(CMClass):
    class TrackedEntryContextManager(CMClass):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._is_entered = False  # Initialize the tracking attribute

        def __enter__(self):
            if not self.is_entered():
                self._is_entered = True
                return super().__enter__()
            else:
                return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self.is_entered():
                self._is_entered = False
                return super().__exit__(exc_type, exc_val, exc_tb)

        def is_entered(self):
            return self._is_entered
    
    # Return the dynamically created subclass
    return TrackedEntryContextManager

# Example usage
class MyContextManager:
    def __enter__(self):
        print("Entering context")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting context")

# Creating a subclass that tracks entry
TrackedMyContextManager = track_entry_context_manager(MyContextManager)

# Using the subclass
with TrackedMyContextManager() as ctx:
    print("Is entered?", ctx.is_entered())  # Should print True
print("Is entered after with block?", ctx.is_entered())  # Should print False




Entering context
Is entered? True
Exiting context
Is entered after with block? False


In [19]:
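class IgnoreIfAlreadyOpen:
    def __init__(self, tracked_context):
        if not callable(getattr(tracked_context, 'is_entered', None)):
            raise AttributeError(
                f"The context is not tracked. Use `track_entry_context_manager` "
                f"to track it: {tracked_context}"
            )
        self.tracked_context = tracked_context
        self._entered_context = tracked_context
        self._did_enter = False  # did *this* wrapper enter the tracked context?

    def __enter__(self):
        if not self.tracked_context.is_entered():
            self._entered_context = self.tracked_context.__enter__()
            self._did_enter = True
        return self._entered_context

    def __exit__(self, *exc):
        # Only exit a context this wrapper entered itself; otherwise an inner
        # wrapper would close a context that was opened further out
        if self._did_enter:
            self._did_enter = False
            return self.tracked_context.__exit__(*exc)
        return None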
    

In [20]:
class W:
    def __init__(self, context):
        self.context = context
        self._entered = False

    def __enter__(self):
        if not self._entered:
            self._entered = True
            return self.context.__enter__()
        else:
            return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._entered:
            self._entered = False
            return self.context.__exit__(exc_type, exc_val, exc_tb)
        else:
            return None

    def __bool__(self):
        return self._entered
    
def do_stuff(context, x):
    print(f"Doing stuff with context and x={x}")

# Create a context manager C
class C:
    def __enter__(self):
        print("Entering context C")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting context C")

# Wrap C in the wrapper W
context = W(C())

# Call do_stuff twice with context
do_stuff(context, 1)
do_stuff(context, 2)

# Now use context in a with block
with context:
    do_stuff(context, 3)
    do_stuff(context, 4)


Doing stuff with context and x=1
Doing stuff with context and x=2
Entering context C
Doing stuff with context and x=3
Doing stuff with context and x=4
Exiting context C
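One caveat with `W` (and with `track_entry_context_manager` above): a single boolean flag doesn't survive nesting. The inner `with` skips the enter, but its exit clears the flag and closes `C` while the outer block is still running. Counting entries instead of flagging them avoids that. A minimal sketch, using a hypothetical `DepthCountingWrapper` and an event list instead of prints:

```python
class DepthCountingWrapper:
    """Like W, but counts nested entries so only the outermost exit closes the context."""

    def __init__(self, context):
        self.context = context
        self._depth = 0
        self._entered_value = None

    def __enter__(self):
        if self._depth == 0:
            self._entered_value = self.context.__enter__()
        self._depth += 1
        return self._entered_value

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._depth -= 1
        if self._depth == 0:
            return self.context.__exit__(exc_type, exc_val, exc_tb)
        return None

    def __bool__(self):
        return self._depth > 0


events = []


class RecordingContext:
    """Stand-in for C that records entries and exits."""

    def __enter__(self):
        events.append("enter")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        events.append("exit")


wrapped = DepthCountingWrapper(RecordingContext())
with wrapped:
    with wrapped:
        events.append("inner work")
    events.append(f"outer work, still open: {bool(wrapped)}")

print(events)
```

With `W` in place of `DepthCountingWrapper`, the same nesting would record the exit right after "inner work", and the outer work would run with the context already closed.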


In [21]:
class TrackEntry:
    """Wraps a context manager and tracks whether the context is "open" or not.

    This is achieved by setting a flag `_is_open` upon entering, and clearing it
    upon exiting, the context.
    """
    def __init__(self, context_manager):
        self.context_manager = context_manager
        self._is_open = False

    def is_open(self):
        return self._is_open

    def __enter__(self):
        if not self._is_open:  # don't re-enter an already open context
            self.context_manager.__enter__()
            self._is_open = True
        return self.context_manager

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._is_open and not self.is_externally_managed():
            self._is_open = False
            return self.context_manager.__exit__(exc_type, exc_val, exc_tb)

    def set_as_externally_managed(self):
        self._externally_managed = True
        return self

    def is_externally_managed(self):
        return getattr(self, '_externally_managed', False)


class IgnoreIfAlreadyOpen:
    """Enters (and later exits) `managed_context` only if it isn't already open."""
    def __init__(self, managed_context):
        self.managed_context = managed_context
        self._opened_here = False

    def __enter__(self):
        if not self.managed_context.is_open():
            self.managed_context.__enter__()
            self._opened_here = True
        return self.managed_context.context_manager

    def __exit__(self, *exc):
        # Only close what this wrapper itself opened
        if self._opened_here:
            self._opened_here = False
            return self.managed_context.__exit__(*exc)

# Usage
class YourContextManager:
    def __enter__(self):
        print("Entering C")
        # Initialize or open your resources here
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting C")
        # Clean up your resources here

c = YourContextManager()
managed_c = TrackEntry(c)

def do_stuff(context, x):
    with IgnoreIfAlreadyOpen(context):
        print(f"Doing stuff with {x}")

# Outside a `with` block, this opens and closes the context on each call
do_stuff(managed_c, 1)
do_stuff(managed_c, 2)

print('')
# # This will open the context once and then reuse it
# with IgnoreIfAlreadyOpen(managed_c.set_as_externally_managed()):
#     do_stuff(managed_c, 1)
#     do_stuff(managed_c, 2)

# Inside a `with` block, the context is opened once and then reused
with managed_c:
    do_stuff(managed_c, 1)
    do_stuff(managed_c, 2)


Entering C
Doing stuff with 1
Exiting C
Entering C
Doing stuff with 2
Exiting C

Entering C
Doing stuff with 1
Doing stuff with 2
Exiting C

In [17]:
class TestConnection:
    def __init__(self, uri='DFLT_URI'):
        self.uri = uri

    def __enter__(self):
        print(f"Opening connection to {self.uri}")
        return self

    def __exit__(self, *exc_info):
        print(f"Closing connection to {self.uri}")

    open = __enter__
    close = __exit__


dao = DAO(TestConnection)    
with dao:
    dao.read('bob')
    dao.read('alice')
print("")
dao.read('bob')
dao.read('alice')

Opening connection to DFLT_URI
    Reading bob
    Reading alice
Closing connection to DFLT_URI

Opening connection to DFLT_URI
    Reading bob
Closing connection to DFLT_URI
Opening connection to DFLT_URI
    Reading alice
Closing connection to DFLT_URI


In [None]:
# Opening connection to DFLT_URI
#     Reading bob
#     Reading alice
# Closing connection to DFLT_URI

# Opening connection to DFLT_URI
#     Reading bob
# Closing connection to DFLT_URI
# Opening connection to DFLT_URI
#     Reading alice
# Closing connection to DFLT_URI

In [14]:
import io
from contextlib import AbstractContextManager

fp = io.StringIO()  # e.g. a file object; an in-memory buffer keeps the cell self-contained
print(isinstance(fp, AbstractContextManager))


True
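The `isinstance` check above works structurally: `AbstractContextManager` treats any class that defines both `__enter__` and `__exit__` as a virtual subclass, without inheritance or registration. That is why a type hint like `Callable[[], AbstractContextManager]` (used for `get_connection` in the next cell) can be satisfied by a plain connection class or by a `contextlib.contextmanager` function. A small check, with hypothetical class and function names:

```python
from contextlib import AbstractContextManager, contextmanager


class PlainConnection:
    """Hypothetical connection class that doesn't inherit from AbstractContextManager."""

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        return None


class NotAContextManager:
    def __enter__(self):  # __exit__ is missing
        return self


@contextmanager
def generator_connection():
    yield "connection"


print(isinstance(PlainConnection(), AbstractContextManager))       # True
print(isinstance(NotAContextManager(), AbstractContextManager))    # False
print(isinstance(generator_connection(), AbstractContextManager))  # True
```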


In [None]:
from typing import Callable
from contextlib import AbstractContextManager

class DAO:
    def __init__(self, get_connection: Callable[[], AbstractContextManager]):
        # FlexibleConnectionManager needs an `is_open` attribute, which a generic
        # context manager doesn't have, and calling get_connection() in every
        # method would give each call its own connection. A single ref-counted
        # manager per DAO works with any context manager factory instead.
        self.connection_manager = RefCountedConnectionManager(get_connection)

    def read(self, k):
        with self.connection_manager:
            print(f"    Reading {k}")

    def __enter__(self):
        self.connection_manager.__enter__()
        return self

    def __exit__(self, *exc_info):
        return self.connection_manager.__exit__(*exc_info)

