# Connections

> Connection management for SSE clients

In [None]:
#| default_exp core.connections

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import asyncio
from typing import Dict, Optional, Any, Set, Callable
from datetime import datetime
from dataclasses import dataclass, field
from enum import Enum
from fasthtml.common import Div, Script, FT

## Core Classes

In [None]:
#| export
class ConnectionState(Enum):
    """Lifecycle states an SSE connection can be in.

    Each member's value is the lowercase form of its name.
    """
    # Created but not yet registered/confirmed
    CONNECTING = "connecting"
    # Registered and able to receive messages
    CONNECTED = "connected"
    # Closed normally
    DISCONNECTED = "disconnected"
    # A send failed (e.g. the queue stayed full past the timeout)
    ERROR = "error"
    # Attempting to re-establish the stream
    RECONNECTING = "reconnecting"

### Tests for ConnectionState

In [None]:
#| test
# Test ConnectionState enum: every member's value is its lowercase name
expected_values = {
    ConnectionState.CONNECTING: "connecting",
    ConnectionState.CONNECTED: "connected",
    ConnectionState.DISCONNECTED: "disconnected",
    ConnectionState.ERROR: "error",
    ConnectionState.RECONNECTING: "reconnecting",
}
for member, value in expected_values.items():
    assert member.value == value

# Test all enum members are present, in declaration order
expected_states = ["CONNECTING", "CONNECTED", "DISCONNECTED", "ERROR", "RECONNECTING"]
assert [member.name for member in ConnectionState] == expected_states

print("✓ ConnectionState enum tests passed")

✓ ConnectionState enum tests passed


In [None]:
#| export
@dataclass
class SSEConnection:
    """One SSE client connection: an outbound message queue plus bookkeeping.

    Messages are delivered by putting them on `queue`. A successful `send`
    bumps `message_count` and refreshes `last_activity`; `heartbeat` also
    refreshes `last_activity`. A send that fails moves `state` to ERROR.
    """
    connection_id: str  # Identifier of this connection
    queue: asyncio.Queue  # Outbound messages for this client
    connection_type: str = "global"  # Grouping label for the connection
    state: ConnectionState = ConnectionState.CONNECTING
    metadata: Dict[str, Any] = field(default_factory=dict)  # Arbitrary caller-supplied data
    created_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)  # Last successful send or heartbeat
    message_count: int = 0  # Number of successful sends

    async def send(
        self,
        data: Any,  # Data to send
        timeout: float = 1.0  # Timeout for the send operation
    ) -> bool:  # True if successful, False otherwise
        """Enqueue `data`, waiting at most `timeout` seconds for queue space.

        On success the activity timestamp and message counter are updated.
        If the put does not complete in time, the connection is flagged
        ERROR and False is returned.
        """
        try:
            await asyncio.wait_for(self.queue.put(data), timeout=timeout)
        except (asyncio.TimeoutError, asyncio.QueueFull):
            self.state = ConnectionState.ERROR
            return False
        else:
            self.last_activity = datetime.now()
            self.message_count += 1
            return True

    async def heartbeat(
        self
    ) -> str:  # SSE formatted heartbeat message
        """Refresh `last_activity` and return it as an SSE comment line (starts with ':')."""
        stamp = datetime.now()
        self.last_activity = stamp
        return f": heartbeat {self.connection_id} {stamp.isoformat()}\n\n"

    def close(self):
        """Flag the connection as DISCONNECTED (the queue itself is left untouched)."""
        self.state = ConnectionState.DISCONNECTED

    def is_active(
        self
    ) -> bool:  # True if connection is active, False otherwise
        """Return whether the connection is CONNECTING or CONNECTED."""
        active_states = (ConnectionState.CONNECTED, ConnectionState.CONNECTING)
        return self.state in active_states

### Tests for SSEConnection

In [None]:
#| test
import asyncio
# Test SSEConnection creation: explicit fields are stored, the rest take defaults
queue = asyncio.Queue()
conn = SSEConnection(
    connection_id="test_conn",
    queue=queue,
    connection_type="custom",
    metadata={"user_id": "123", "role": "viewer"},
)

# Explicitly supplied fields
assert conn.connection_id == "test_conn"
assert conn.queue == queue
assert conn.connection_type == "custom"
assert (conn.metadata["user_id"], conn.metadata["role"]) == ("123", "viewer")

# Defaults
assert conn.state == ConnectionState.CONNECTING
assert conn.message_count == 0
assert all(isinstance(ts, datetime) for ts in (conn.created_at, conn.last_activity))

print("✓ SSEConnection creation tests passed")

✓ SSEConnection creation tests passed


In [None]:
#| test
# Test SSEConnection methods
async def test_sse_connection_methods():
    queue = asyncio.Queue(maxsize=5)
    conn = SSEConnection(
        connection_id="test_methods",
        queue=queue
    )
    
    # Test send method
    success = await conn.send("test message", timeout=1.0)
    assert success == True
    assert conn.message_count == 1
    
    # Verify message was queued
    msg = await asyncio.wait_for(queue.get(), timeout=1.0)
    assert msg == "test message"
    
    # Test send with multiple messages
    for i in range(3):
        success = await conn.send(f"message_{i}")
        assert success == True
    assert conn.message_count == 4
    
    # Test is_active method
    conn.state = ConnectionState.CONNECTED
    assert conn.is_active() == True
    
    conn.state = ConnectionState.CONNECTING
    assert conn.is_active() == True
    
    conn.state = ConnectionState.DISCONNECTED
    assert conn.is_active() == False
    
    conn.state = ConnectionState.ERROR
    assert conn.is_active() == False
    
    # Test close method
    conn.state = ConnectionState.CONNECTED
    conn.close()
    assert conn.state == ConnectionState.DISCONNECTED
    assert conn.is_active() == False
    
    print("✓ SSEConnection methods tests passed")

await test_sse_connection_methods()

✓ SSEConnection methods tests passed


In [None]:
#| test
# Test SSEConnection heartbeat
async def test_heartbeat():
    queue = asyncio.Queue()
    conn = SSEConnection(
        connection_id="heartbeat_test",
        queue=queue
    )
    
    original_activity = conn.last_activity
    
    # Small delay to ensure time difference
    await asyncio.sleep(0.01)
    
    heartbeat_msg = await conn.heartbeat()
    
    # Check heartbeat message format
    assert isinstance(heartbeat_msg, str)
    assert ": heartbeat" in heartbeat_msg
    assert "heartbeat_test" in heartbeat_msg
    assert heartbeat_msg.endswith("\n\n")
    
    # Check that last_activity was updated
    assert conn.last_activity > original_activity
    
    print("✓ SSEConnection heartbeat tests passed")

await test_heartbeat()

✓ SSEConnection heartbeat tests passed


In [None]:
#| export
class ConnectionRegistry:
    """Registry to track and manage SSE connections.

    Connections are indexed twice: by ID in `connections`, and by type in
    `connections_by_type` (type -> set of IDs). Mutations go through
    `add_connection` / `remove_connection`, which hold `self.lock`; the read
    accessors are synchronous and read the dicts directly.
    """
    
    def __init__(
        self,
        debug: bool = False  # Enable debug logging
    ):
        """Initialize an empty registry."""
        self.connections: Dict[str, SSEConnection] = {}  # conn_id -> connection
        self.connections_by_type: Dict[str, Set[str]] = {}  # conn_type -> conn_ids; empty buckets are deleted
        self.lock = asyncio.Lock()
        self.debug = debug
        self._counter = 0  # Shared by all types; feeds auto-generated IDs
    
    def _untrack(
        self,
        conn_id: str  # Connection ID to drop from both indexes
    ) -> Optional[SSEConnection]:  # The removed connection, or None if it was not registered
        """Remove `conn_id` from both indexes without closing it (call with `self.lock` held)."""
        connection = self.connections.pop(conn_id, None)
        if connection is None:
            return None
        conn_type = connection.connection_type
        ids = self.connections_by_type.get(conn_type)
        if ids is not None:
            ids.discard(conn_id)
            if not ids:
                # Drop empty buckets so `get_stats()["by_type"]` only lists live types
                del self.connections_by_type[conn_type]
        return connection
    
    async def add_connection(self,
                            conn_id: Optional[str] = None,  # Optional connection ID (auto-generated if not provided)
                            conn_type: str = "global",  # Type of connection (e.g., 'global', 'job', 'user')
                            queue_size: int = 100,  # Size of the message queue
                            metadata: Optional[Dict[str, Any]] = None # Optional metadata for the connection
                            ) -> SSEConnection: # The created SSEConnection
        """Create, register, and return a new connection in the CONNECTED state.

        If `conn_id` is omitted, an ID of the form ``f"{conn_type}_{n}"`` is
        generated from a counter shared by all types, skipping values already
        in use. If `conn_id` names an existing connection, that connection is
        closed and replaced, and its entry in its old type bucket is removed.
        """
        async with self.lock:
            if conn_id is None:
                # Advance until the generated ID does not collide with one that
                # was registered explicitly (e.g. a caller-chosen "user_3").
                while True:
                    self._counter += 1
                    conn_id = f"{conn_type}_{self._counter}"
                    if conn_id not in self.connections:
                        break
            else:
                previous = self._untrack(conn_id)
                if previous is not None:
                    # Otherwise the ID would linger in its former type's bucket
                    # and the replaced connection would still report active.
                    previous.close()
                    if self.debug:
                        print(f"[ConnectionRegistry] Replaced connection: {conn_id}")
            
            queue = asyncio.Queue(maxsize=queue_size)
            connection = SSEConnection(
                connection_id=conn_id,
                queue=queue,
                connection_type=conn_type,
                metadata=metadata or {}
            )
            
            self.connections[conn_id] = connection
            self.connections_by_type.setdefault(conn_type, set()).add(conn_id)
            
            connection.state = ConnectionState.CONNECTED
            
            if self.debug:
                print(f"[ConnectionRegistry] Added connection: {conn_id} (type: {conn_type})")
            
            return connection
    
    async def remove_connection(
        self,
        conn_id: str  # Connection ID to remove
    ):
        """Close and unregister a connection; unknown IDs are ignored."""
        async with self.lock:
            connection = self._untrack(conn_id)
            if connection is None:
                return
            connection.close()
            
            if self.debug:
                print(f"[ConnectionRegistry] Removed connection: {conn_id}")
    
    def get_connection(
        self,
        conn_id: str  # Connection ID
    ) -> Optional[SSEConnection]:  # The connection if found, None otherwise
        """Get a specific connection."""
        return self.connections.get(conn_id)
    
    def get_connections(
        self,
        conn_type: Optional[str] = None  # Optional connection type to filter by
    ) -> list[SSEConnection]:  # List of connections
        """Get connections, optionally filtered by type (a falsy type returns all)."""
        if conn_type:
            conn_ids = self.connections_by_type.get(conn_type, set())
            return [self.connections[cid] for cid in conn_ids if cid in self.connections]
        return list(self.connections.values())
    
    def get_active_connections(
        self,
        conn_type: Optional[str] = None  # Optional connection type to filter by
    ) -> list[SSEConnection]:  # List of active connections
        """Get connections whose `is_active()` is True, optionally filtered by type."""
        return [c for c in self.get_connections(conn_type) if c.is_active()]
    
    def get_stats(
        self
    ) -> Dict[str, Any]:  # Dictionary with connection statistics
        """Return total/active counts, per-type counts, and the summed message count of registered connections."""
        active_conns = self.get_active_connections()
        
        return {
            "total": len(self.connections),
            "active": len(active_conns),
            "by_type": {k: len(v) for k, v in self.connections_by_type.items()},
            "message_count": sum(c.message_count for c in self.connections.values())
        }

### Tests for ConnectionRegistry

In [None]:
#| test
# Test ConnectionRegistry initialization: a fresh registry is empty and quiet
registry = ConnectionRegistry(debug=False)

assert not registry.connections
assert not registry.connections_by_type
assert not registry.debug
assert registry._counter == 0

# Stats of an empty registry are all zero / empty
stats = registry.get_stats()
for key, expected in {"total": 0, "active": 0, "by_type": {}, "message_count": 0}.items():
    assert stats[key] == expected

print("✓ ConnectionRegistry initialization tests passed")

✓ ConnectionRegistry initialization tests passed


In [None]:
#| test
# Test adding and removing connections
async def test_registry_operations():
    registry = ConnectionRegistry()
    
    # Test adding connection without ID
    conn1 = await registry.add_connection(conn_type="user", queue_size=50)
    assert conn1.connection_id == "user_1"
    assert conn1.connection_type == "user"
    assert conn1.state == ConnectionState.CONNECTED
    assert len(registry.connections) == 1
    assert "user" in registry.connections_by_type
    assert "user_1" in registry.connections_by_type["user"]
    
    # Test adding connection with custom ID
    conn2 = await registry.add_connection(
        conn_id="custom_conn",
        conn_type="admin",
        metadata={"permissions": ["read", "write"]}
    )
    assert conn2.connection_id == "custom_conn"
    assert conn2.connection_type == "admin"
    assert conn2.metadata["permissions"] == ["read", "write"]
    assert len(registry.connections) == 2
    
    # Test adding multiple connections of same type
    conn3 = await registry.add_connection(conn_type="user")
    assert conn3.connection_id == "user_2"
    assert len(registry.connections_by_type["user"]) == 2
    
    # Test get_connection
    retrieved = registry.get_connection("custom_conn")
    assert retrieved == conn2
    assert registry.get_connection("non_existent") is None
    
    # Test get_connections by type
    user_conns = registry.get_connections("user")
    assert len(user_conns) == 2
    assert conn1 in user_conns
    assert conn3 in user_conns
    
    # Test get all connections
    all_conns = registry.get_connections()
    assert len(all_conns) == 3
    
    # Test removing connection
    await registry.remove_connection("user_1")
    assert len(registry.connections) == 2
    assert "user_1" not in registry.connections
    assert len(registry.connections_by_type["user"]) == 1
    
    # Remove remaining user connection
    await registry.remove_connection("user_2")
    assert "user" not in registry.connections_by_type  # Type should be removed when empty
    
    # Clean up
    await registry.remove_connection("custom_conn")
    assert len(registry.connections) == 0
    
    print("✓ ConnectionRegistry operations tests passed")

await test_registry_operations()

✓ ConnectionRegistry operations tests passed


In [None]:
#| test
# Test get_active_connections and stats
async def test_registry_active_and_stats():
    registry = ConnectionRegistry()
    
    # Add connections with different states
    conn1 = await registry.add_connection(conn_id="active1", conn_type="user")
    conn2 = await registry.add_connection(conn_id="active2", conn_type="user")
    conn3 = await registry.add_connection(conn_id="inactive", conn_type="admin")
    
    # Send some messages to track count
    await conn1.send("msg1")
    await conn1.send("msg2")
    await conn2.send("msg3")
    
    # Change state of one connection
    conn3.state = ConnectionState.DISCONNECTED
    
    # Test get_active_connections
    active_all = registry.get_active_connections()
    assert len(active_all) == 2  # Only conn1 and conn2 are active
    assert conn1 in active_all
    assert conn2 in active_all
    assert conn3 not in active_all
    
    # Test get_active_connections by type
    active_users = registry.get_active_connections("user")
    assert len(active_users) == 2
    
    active_admins = registry.get_active_connections("admin")
    assert len(active_admins) == 0  # conn3 is disconnected
    
    # Test statistics
    stats = registry.get_stats()
    assert stats["total"] == 3
    assert stats["active"] == 2
    assert stats["by_type"]["user"] == 2
    assert stats["by_type"]["admin"] == 1
    assert stats["message_count"] == 3  # Total messages sent
    
    # Clean up
    await registry.remove_connection("active1")
    await registry.remove_connection("active2")
    await registry.remove_connection("inactive")
    
    print("✓ ConnectionRegistry active connections and stats tests passed")

await test_registry_active_and_stats()

✓ ConnectionRegistry active connections and stats tests passed


## Helper Functions

In [None]:
#| export
def create_sse_element(endpoint: str,  # URL of the SSE stream (rendered as `sse-connect`)
                      element_id: Optional[str] = None,  # Optional element ID
                      swap_strategy: str = "message",  # Value rendered as `sse-swap` (see NOTE in docstring)
                      hidden: bool = False,  # Whether to hide the element
                      **attrs # Additional attributes for the element (override the generated ones)
                      ) -> Div:  # SSE-enabled Div element configured for HTMX
    """Create a `Div` that opens an SSE connection through htmx's `sse` extension.

    The element gets `hx-ext="sse"`, `sse-connect=endpoint` and
    `sse-swap=swap_strategy`, plus `id` when `element_id` is truthy. Extra
    keyword arguments are merged last, so they override generated attributes.
    With `hidden=True`, `display: none;` is appended to any caller-supplied
    `style` rather than being silently discarded by it.

    NOTE(review): in htmx's SSE extension, `sse-swap` names the SSE event(s)
    whose data is swapped in. Swap styles such as `innerHTML` belong in
    `hx-swap`. The parameter name suggests otherwise; confirm the intent.
    """
    sse_attrs = {
        'hx_ext': 'sse',
        'sse_connect': endpoint,
        'sse_swap': swap_strategy
    }
    
    if element_id:
        sse_attrs['id'] = element_id
    
    # Caller-supplied attributes win over the generated ones
    sse_attrs.update(attrs)
    
    if hidden:
        # Keep any caller style and append display:none so the element stays
        # hidden (later declarations win in CSS).
        user_style = str(sse_attrs.get('style') or '').strip().rstrip(';').strip()
        sse_attrs['style'] = f"{user_style}; display: none;" if user_style else 'display: none;'
    
    return Div(**sse_attrs)

In [None]:
#| export
def cleanup_sse_on_unload(
) -> Script:  # Script element for cleanup
    """Create a script to clean up SSE connections on page unload.

    On `beforeunload`, the script closes the EventSource that htmx keeps at
    `element['htmx-internal-data'].sseEventSource` for every `[sse-connect]`
    element. If `htmx` is defined, it then triggers `htmx:sseClose` on each
    such element. The `visibilitychange` listener only writes a console.log
    message; it does not close or reopen any connection.
    """
    # Plain (non-f) string, so the JS braces need no escaping.
    cleanup_code = """
    window.addEventListener('beforeunload', function() {
        // Close all SSE connections
        document.querySelectorAll('[sse-connect]').forEach(function(element) {
            var internalData = element['htmx-internal-data'];
            if (internalData && internalData.sseEventSource) {
                internalData.sseEventSource.close();
            }
        });
        
        // Trigger HTMX cleanup
        if (typeof htmx !== 'undefined') {
            htmx.findAll('[sse-connect]').forEach(function(element) {
                htmx.trigger(element, 'htmx:sseClose');
            });
        }
    });
    
    // Also handle page visibility changes
    document.addEventListener('visibilitychange', function() {
        if (document.hidden) {
            // Page is hidden - connections will be suspended by browser
            console.log('[SSE] Page hidden - connections suspended');
        } else {
            // Page is visible - connections may need reconnection
            console.log('[SSE] Page visible - checking connections');
        }
    });
    """
    
    # NOTE(review): relies on htmx's internal per-element data key
    # ('htmx-internal-data'); confirm it exists in the htmx version in use.
    return Script(cleanup_code)

In [None]:
#| export
def create_reconnection_script(check_interval: int = 5000,  # Polling interval in milliseconds
                              max_retries: int = 5,  # Maximum number of reconnection attempts
                              debug: bool = False  # Enable debug logging
                              ) -> Script:  # Script element with reconnection logic
    """Create a script that polls SSE elements and requests reconnection of closed ones.

    Every `check_interval` ms the script inspects the EventSource of each
    `[sse-connect]` element, read from htmx's internal data. When it is
    CLOSED, the script triggers `htmx:sseReconnect` on the element, at most
    `max_retries` consecutive times. An OPEN source resets that element's
    count, and the page becoming visible resets all counts.

    Counts are kept per element in a WeakMap. Previously they were keyed by
    `element.id || 'unknown'`, so all id-less elements shared one budget.
    """
    debug_str = "true" if debug else "false"
    
    reconnect_code = f"""
    (function() {{
        var checkInterval = {check_interval};
        var maxRetries = {max_retries};
        var debug = {debug_str};
        // Keyed by element (not element.id) so id-less elements get independent counters
        var retryCount = new WeakMap();
        
        function log(message) {{
            if (debug) {{
                console.log('[SSE Reconnect] ' + message);
            }}
        }}
        
        function checkConnection(element) {{
            var internalData = element['htmx-internal-data'];
            if (!internalData || !internalData.sseEventSource) {{
                return;
            }}
            
            var source = internalData.sseEventSource;
            var elementId = element.id || 'unknown';
            
            if (source.readyState === EventSource.CLOSED) {{
                var attempts = retryCount.get(element) || 0;
                
                if (attempts < maxRetries) {{
                    attempts++;
                    retryCount.set(element, attempts);
                    log('Reconnecting ' + elementId + ' (attempt ' + attempts + ')');
                    htmx.trigger(element, 'htmx:sseReconnect');
                }} else {{
                    log('Max retries reached for ' + elementId);
                }}
            }} else if (source.readyState === EventSource.OPEN) {{
                // Reset retry count on successful connection
                retryCount.delete(element);
            }}
        }}
        
        // Check connections periodically
        setInterval(function() {{
            document.querySelectorAll('[sse-connect]').forEach(checkConnection);
        }}, checkInterval);
        
        // Reset retry counts on visibility change
        document.addEventListener('visibilitychange', function() {{
            if (!document.hidden) {{
                log('Page visible - resetting retry counts');
                retryCount = new WeakMap();
            }}
        }});
    }})();
    """
    
    return Script(reconnect_code)

In [None]:
#| export
def create_connection_manager_script(registry_endpoint: str = "/sse/connections",  # URL fetched for connection stats; the response is parsed as JSON
                                    update_interval: int = 10000  # Interval for updating connection stats (in milliseconds)
                                    ) -> Script:  # Script element with connection management logic
    """Create a script that polls a stats endpoint and publishes the results.

    The script fetches `registry_endpoint` 100 ms after it runs, then every
    `update_interval` ms, and whenever `htmx:sseOpen` / `htmx:sseClose`
    reaches `document`. Each successful JSON response is dispatched as an
    `sse:stats` CustomEvent on `document`. Every element with
    `data-sse-stats="<key>"` has its text set to `data[<key>]` when that key
    is present. Non-2xx responses are logged instead of being published.
    """
    # The endpoint is embedded in a single-quoted JS string inside a <script>
    # element: escape characters that would end the literal, and "</" so a
    # value containing "</script>" cannot close the element early.
    endpoint_js = (
        registry_endpoint
        .replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("</", "<\\/")
    )
    manager_code = f"""
    (function() {{
        var registryEndpoint = '{endpoint_js}';
        var updateInterval = {update_interval};
        
        function updateConnectionStats() {{
            fetch(registryEndpoint)
                .then(response => {{
                    if (!response.ok) {{
                        throw new Error('HTTP ' + response.status);
                    }}
                    return response.json();
                }})
                .then(data => {{
                    // Dispatch custom event with connection stats
                    var event = new CustomEvent('sse:stats', {{
                        detail: data
                    }});
                    document.dispatchEvent(event);
                    
                    // Update any elements with data-sse-stats attribute
                    document.querySelectorAll('[data-sse-stats]').forEach(function(element) {{
                        var statType = element.getAttribute('data-sse-stats');
                        if (data[statType] !== undefined) {{
                            element.textContent = data[statType];
                        }}
                    }});
                }})
                .catch(error => {{
                    console.error('[SSE Manager] Failed to fetch stats:', error);
                }});
        }}
        
        // Initial update
        setTimeout(updateConnectionStats, 100);
        
        // Periodic updates
        setInterval(updateConnectionStats, updateInterval);
        
        // Listen on document (htmx events bubble) rather than document.body,
        // which is still null when this script runs from <head>.
        document.addEventListener('htmx:sseOpen', function(evt) {{
            console.log('[SSE Manager] Connection opened:', evt.detail.elt.id);
            updateConnectionStats();
        }});
        
        document.addEventListener('htmx:sseClose', function(evt) {{
            console.log('[SSE Manager] Connection closed:', evt.detail.elt.id);
            updateConnectionStats();
        }});
    }})();
    """
    
    return Script(manager_code)

### Tests for Helper Functions

In [None]:
#| test
# Test create_sse_element
element = create_sse_element(
    endpoint="/sse/test",
    element_id="sse-container",
    swap_strategy="innerHTML",
    hidden=False
)

assert isinstance(element, FT)
assert element.attrs.get('id') == "sse-container"
assert element.attrs.get('hx-ext') == "sse"
assert element.attrs.get('sse-connect') == "/sse/test"
assert element.attrs.get('sse-swap') == "innerHTML"
assert 'style' not in element.attrs  # Not hidden

# Test with hidden=True
hidden_element = create_sse_element(
    endpoint="/sse/hidden",
    hidden=True
)
assert hidden_element.attrs.get('style') == 'display: none;'

# Test with additional attributes.
# FastHTML maps `cls` to the HTML `class` attribute; `class_` would instead
# produce a meaningless `class-` attribute.
custom_element = create_sse_element(
    endpoint="/sse/custom",
    element_id="custom",
    cls="sse-stream",
    data_test="value"
)
assert custom_element.attrs.get('class') == "sse-stream"
assert custom_element.attrs.get('data-test') == "value"

print("✓ create_sse_element tests passed")

✓ create_sse_element tests passed


In [None]:
#| test
# Test cleanup_sse_on_unload
cleanup_script = cleanup_sse_on_unload()

assert isinstance(cleanup_script, FT)

# The script body is the first child; check the behaviours it must wire up
children = cleanup_script.children
script_content = str(children[0]) if children else ""
required_snippets = ["beforeunload", "sseEventSource.close()", "visibilitychange", "htmx:sseClose"]
missing = [snippet for snippet in required_snippets if snippet not in script_content]
assert not missing, missing

print("✓ cleanup_sse_on_unload tests passed")

✓ cleanup_sse_on_unload tests passed


In [None]:
#| test
# Test create_reconnection_script
reconnect_script = create_reconnection_script(check_interval=3000, max_retries=3, debug=True)

assert isinstance(reconnect_script, FT)

children = reconnect_script.children
script_content = str(children[0]) if children else ""

# Configuration values are interpolated into the JS
for snippet in ("checkInterval = 3000", "maxRetries = 3", "debug = true"):
    assert snippet in script_content

# Key functionality
for snippet in ("EventSource.CLOSED", "htmx:sseReconnect", "retryCount", "[SSE Reconnect]"):
    assert snippet in script_content

# debug=False renders as a JS `false`
reconnect_no_debug = create_reconnection_script(debug=False)
no_debug_children = reconnect_no_debug.children
no_debug_content = str(no_debug_children[0]) if no_debug_children else ""
assert "debug = false" in no_debug_content

print("✓ create_reconnection_script tests passed")

✓ create_reconnection_script tests passed


In [None]:
#| test
# Test create_connection_manager_script
manager_script = create_connection_manager_script(registry_endpoint="/api/connections", update_interval=5000)

assert isinstance(manager_script, FT)

children = manager_script.children
script_content = str(children[0]) if children else ""

# Configuration values are interpolated into the JS
for snippet in ("registryEndpoint = '/api/connections'", "updateInterval = 5000"):
    assert snippet in script_content

# Key functionality
for snippet in ("updateConnectionStats", "sse:stats", "data-sse-stats",
                "htmx:sseOpen", "htmx:sseClose", "[SSE Manager]"):
    assert snippet in script_content

# Defaults are used when no arguments are given
default_manager = create_connection_manager_script()
default_children = default_manager.children
default_content = str(default_children[0]) if default_children else ""
for snippet in ("registryEndpoint = '/sse/connections'", "updateInterval = 10000"):
    assert snippet in default_content

print("✓ create_connection_manager_script tests passed")

✓ create_connection_manager_script tests passed


## Export

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()