# broadcast

> SSE broadcast infrastructure for managing connections and broadcasting updates to multiple clients

In [None]:
#| default_exp core.broadcast

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import asyncio
from typing import Set, Dict, Any, Optional, Deque, List, Callable, Awaitable
from collections import deque
from datetime import datetime
import json
from dataclasses import dataclass, field
from enum import Enum

## Broadcast Message Types

In [None]:
#| export
class BroadcastEventType(Enum):
    """
    Well-known event type names for broadcast messages.

    Each member's value is the lowercase string form of its name; use ``.value``
    when a plain string event type is required.
    """

    UPDATE = "update"
    CREATE = "create"
    DELETE = "delete"
    REFRESH = "refresh"
    STATUS = "status"
    ERROR = "error"
    CUSTOM = "custom"

In [None]:
#| export
@dataclass
class BroadcastMessage:
    """
    Structured broadcast message.

    ``type`` is normally a plain string. An ``Enum`` member (for example a
    ``BroadcastEventType``) is also accepted and is serialized as its ``.value``,
    so ``to_json`` works for both forms.
    """
    type: str  # Event type (str, or an Enum member whose .value is used when serializing)
    data: Dict[str, Any]  # Message payload; must be JSON-serializable for to_json()
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())  # Naive local time (ISO 8601) captured at construction
    target_ids: Optional[List[str]] = None  # Specific element IDs to update
    metadata: Optional[Dict[str, Any]] = None  # Optional extra information attached to the message

    def to_dict(
        self
    ) -> Dict[str, Any]:  # Dictionary representation of the message
        """
        Convert to dictionary for JSON serialization.

        ``target_ids`` and ``metadata`` are only included when they are non-empty.
        """
        # Enum members are not JSON-serializable; emit their underlying value instead
        event_type = self.type.value if isinstance(self.type, Enum) else self.type
        result = {
            "type": event_type,
            "data": self.data,
            "timestamp": self.timestamp
        }
        if self.target_ids:
            result["target_ids"] = self.target_ids
        if self.metadata:
            result["metadata"] = self.metadata
        return result

    def to_json(
        self
    ) -> str:  # JSON string representation of the message
        """Convert to JSON string (raises TypeError if ``data``/``metadata`` hold non-JSON values)."""
        return json.dumps(self.to_dict())

## SSE Broadcast Manager

In [None]:
#| export
class SSEBroadcastManager:
    """
    Manages SSE connections and broadcasting across multiple clients. This class provides a centralized way to manage Server-Sent Events connections and broadcast messages to all connected clients, enabling real-time updates and cross-tab synchronization.

    Each connection is an ``asyncio.Queue``. A broadcast puts the same
    ``BroadcastMessage`` object on every registered queue. A client whose queue is
    still full after ``timeout`` seconds, or whose put raises, is dropped from the
    connection set.
    """
    
    def __init__(self, 
                 history_size: int = 50,  # Number of messages to keep in history (for debugging/replay)
                 queue_size: int = 100,  # Maximum size of each client's message queue
                 timeout: float = 0.1,  # Timeout for queue operations in seconds
                 debug: bool = False): # Enable debug logging
        """
        Initialize the SSE Broadcast Manager.
        """
        self.connections: Set[asyncio.Queue] = set()
        # Guards `connections` and `_client_metadata`
        self.lock = asyncio.Lock()
        # Oldest entries are evicted automatically once `history_size` is reached
        self.history: Deque[BroadcastMessage] = deque(maxlen=history_size)
        self.queue_size = queue_size
        self.timeout = timeout
        self.debug = debug
        self._client_metadata: Dict[asyncio.Queue, Dict[str, Any]] = {}
        self._message_handlers: List[Callable[[BroadcastMessage], Awaitable[Optional[BroadcastMessage]]]] = []
    
    def _log(
        self,
        message: str  # Log message to print if debug is enabled
    ):
        """Internal logging method (prints only when `debug` is enabled)"""
        if self.debug:
            print(f"[SSEBroadcastManager] {message}")
    
    async def register_connection(self, 
                                  queue: Optional[asyncio.Queue] = None,  # Optional pre-existing queue (creates new if None)
                                  metadata: Optional[Dict[str, Any]] = None # Optional metadata for this connection
                                 ) -> asyncio.Queue: # The queue for this connection
        """
        Register a new client connection.

        A new queue bounded by `queue_size` is created when none is given.
        Metadata is stored only when it is non-empty.
        """
        if queue is None:
            queue = asyncio.Queue(maxsize=self.queue_size)
        
        async with self.lock:
            self.connections.add(queue)
            if metadata:
                self._client_metadata[queue] = metadata
            self._log(f"Connection registered. Total connections: {len(self.connections)}")
        
        return queue
    
    async def unregister_connection(
        self,
        queue: asyncio.Queue  # The queue to unregister
    ):
        """
        Remove a client connection. Unknown queues are ignored.
        """
        async with self.lock:
            self.connections.discard(queue)
            self._client_metadata.pop(queue, None)
            self._log(f"Connection unregistered. Total connections: {len(self.connections)}")
    
    async def broadcast(self, 
                        event_type: str,   # Type of event to broadcast
                        data: Dict[str, Any], # Data to broadcast
                        target_ids: Optional[List[str]] = None,  # Optional list of element IDs to target for updates
                        metadata: Optional[Dict[str, Any]] = None # Optional metadata for the message
                       ) -> int: # Number of clients successfully notified
        """
        Build a `BroadcastMessage` and send it to all connected clients.

        `target_ids` is carried on the message. It does not restrict which clients
        receive it.
        """
        message = BroadcastMessage(
            type=event_type,
            data=data,
            target_ids=target_ids,
            metadata=metadata
        )
        
        return await self.broadcast_message(message)
    
    async def run_message_handlers(
        self,
        message: BroadcastMessage  # Message to process
    ) -> BroadcastMessage:  # Potentially modified message
        """
        Run message through registered handlers, in registration order.

        A handler's non-None return value replaces the message for later handlers.
        Exceptions raised by a handler propagate to the caller.
        """
        for handler in self._message_handlers:
            result = await handler(message)
            if result is not None:
                message = result
        return message
    
    async def _deliver(
        self,
        queue: asyncio.Queue,  # Client queue to put the message on
        message: BroadcastMessage  # Message to deliver
    ) -> bool:  # True if delivered, False if the client should be dropped
        """Put `message` on one client queue, waiting at most `timeout` seconds if it is full."""
        try:
            try:
                # Fast path: a queue with free space accepts the message immediately,
                # without creating a task/timer for `wait_for`
                queue.put_nowait(message)
            except asyncio.QueueFull:
                # Slow path: give a lagging client up to `timeout` seconds to drain
                await asyncio.wait_for(queue.put(message), timeout=self.timeout)
        except (asyncio.TimeoutError, asyncio.QueueFull):
            self._log("Failed to send to client (timeout/full queue)")
            return False
        except Exception as e:
            self._log(f"Error broadcasting to client: {e}")
            return False
        return True
    
    async def broadcast_message(
        self,
        message: BroadcastMessage  # Message to broadcast
    ) -> int:  # Number of clients successfully notified
        """
        Broadcast a pre-constructed message to all connected clients.

        The message passes through the handlers first. It is then recorded in history
        (even when nobody is connected) and delivered to every registered queue.
        Queues that cannot accept it are unregistered. The lock is held for the whole
        delivery loop, so each full queue can add up to `timeout` seconds of latency.
        """
        # Run through handlers
        message = await self.run_message_handlers(message)
        
        # Store in history
        self.history.append(message)
        
        # Broadcast to all active connections
        async with self.lock:
            if not self.connections:
                self._log("No active connections to broadcast to")
                return 0
            
            total = len(self.connections)
            disconnected = set()
            successful = 0
            
            for queue in self.connections:
                if await self._deliver(queue, message):
                    successful += 1
                else:
                    disconnected.add(queue)
            
            # Clean up disconnected clients (after the loop: the set must not change while iterating)
            for queue in disconnected:
                self.connections.discard(queue)
                self._client_metadata.pop(queue, None)
            
            if disconnected:
                self._log(f"Removed {len(disconnected)} disconnected clients")
            
            # Report against the number of clients attempted, not the post-cleanup count
            self._log(f"Broadcast to {successful}/{total} clients")
            return successful
    
    def add_message_handler(self, 
                            handler: Callable[[BroadcastMessage], Awaitable[Optional[BroadcastMessage]]] # Async function that takes and optionally returns a BroadcastMessage
                           ):
        """
        Add a message handler that can transform messages before broadcasting. Handlers are called in the order they are added and can modify messages or return None to pass through unchanged.
        """
        self._message_handlers.append(handler)
    
    def get_connection_count(
        self
    ) -> int:  # Current number of active connections
        """Get the current number of active connections"""
        return len(self.connections)
    
    def get_history(
        self,
        limit: Optional[int] = None  # Maximum number of messages to return (None for all)
    ) -> List[BroadcastMessage]:  # List of historical messages, oldest first
        """
        Get message history, optionally only the most recent `limit` messages.

        A `limit` of 0 or less returns an empty list.
        """
        history = list(self.history)
        if limit is None:
            return history
        if limit <= 0:
            # `history[-0:]` would return everything; a zero limit means "no messages"
            return []
        return history[-limit:]
    
    async def clear_history(self):
        """Clear the message history"""
        self.history.clear()
        self._log("Message history cleared")
    
    async def close_all_connections(self):
        """
        Close all active connections. Useful for cleanup during shutdown.

        This only forgets the queues. Nothing is put on them, so a consumer awaiting
        `queue.get()` is not woken.
        NOTE(review): consumers need their own shutdown signal; confirm this is intended.
        """
        async with self.lock:
            self._log(f"Closing {len(self.connections)} connections")
            self.connections.clear()
            self._client_metadata.clear()

## Connection Context Manager

In [None]:
#| export
class SSEConnection:
    """
    Context manager for SSE connections. Automatically registers/unregisters connections with the broadcast manager.

    Usage::

        async with SSEConnection(manager, metadata={"client_id": "abc"}) as queue:
            message = await queue.get()
    """
    
    def __init__(self, 
                 manager: SSEBroadcastManager,  # The SSEBroadcastManager to register with
                 metadata: Optional[Dict[str, Any]] = None # Optional metadata for this connection
                ):
        """
        Initialize SSE connection context. Nothing is registered until entry.
        """
        self.manager = manager
        self.metadata = metadata
        # Set by __aenter__; stays None if the context was never entered
        self.queue: Optional[asyncio.Queue] = None
    
    async def __aenter__(
        self
    ) -> asyncio.Queue:  # The message queue for this connection
        """Register a new connection with the manager and return its queue"""
        self.queue = await self.manager.register_connection(metadata=self.metadata)
        return self.queue
    
    async def __aexit__(
        self,
        exc_type: Optional[type],  # Exception type if an exception occurred
        exc_val: Optional[Exception],  # Exception value if an exception occurred
        exc_tb: Optional[Any]  # Exception traceback if an exception occurred
    ):
        """Unregister the connection on exit; exceptions from the body are not suppressed"""
        # Compare against None explicitly: a queue type that defines __len__ would be
        # falsy while empty and would otherwise never be unregistered
        if self.queue is not None:
            await self.manager.unregister_connection(self.queue)

## Example Usage

In [None]:
# Example: Basic broadcast manager usage
async def example_usage():
    # Create manager
    manager = SSEBroadcastManager(debug=True)
    
    # Register a connection using context manager
    async with SSEConnection(manager, metadata={"client_id": "test-1"}) as queue:
        # Broadcast a message
        await manager.broadcast(
            event_type="notification",
            data={"message": "Hello, world!"},
            target_ids=["notification-area"]
        )
        
        # Check if message was received
        try:
            message = await asyncio.wait_for(queue.get(), timeout=1.0)
            print(f"Received: {message.to_dict()}")
        except asyncio.TimeoutError:
            print("No message received")
    
    # Connection automatically unregistered
    print(f"Active connections: {manager.get_connection_count()}")

# Run example
await example_usage()

[SSEBroadcastManager] Connection registered. Total connections: 1
[SSEBroadcastManager] Broadcast to 1/1 clients
Received: {'type': 'notification', 'data': {'message': 'Hello, world!'}, 'timestamp': '2025-08-27T13:15:47.159952', 'target_ids': ['notification-area']}
[SSEBroadcastManager] Connection unregistered. Total connections: 0
Active connections: 0


In [None]:
# Example: Message handler for adding metadata
async def example_with_handler():
    manager = SSEBroadcastManager(debug=True)
    
    # Add a handler that adds server info to all messages
    async def add_server_info(message: BroadcastMessage) -> BroadcastMessage:
        if message.metadata is None:
            message.metadata = {}
        message.metadata["server_id"] = "server-1"
        message.metadata["processed_at"] = datetime.now().isoformat()
        return message
    
    manager.add_message_handler(add_server_info)
    
    # Register connection and broadcast
    async with SSEConnection(manager) as queue:
        await manager.broadcast("test", {"value": 42})
        
        message = await queue.get()
        print(f"Message metadata: {message.metadata}")

await example_with_handler()

[SSEBroadcastManager] Connection registered. Total connections: 1
[SSEBroadcastManager] Broadcast to 1/1 clients
Message metadata: {'server_id': 'server-1', 'processed_at': '2025-08-27T13:15:47.176562'}
[SSEBroadcastManager] Connection unregistered. Total connections: 0


In [None]:
# Example: Multiple connections and selective broadcasting
async def example_multiple_connections():
    manager = SSEBroadcastManager(debug=True)
    
    # Create multiple connections
    connections = []
    for i in range(3):
        queue = await manager.register_connection(
            metadata={"client_id": f"client-{i}"}
        )
        connections.append(queue)
    
    print(f"Registered {manager.get_connection_count()} connections")
    
    # Broadcast to all
    successful = await manager.broadcast(
        "announcement",
        {"message": "System update"},
        target_ids=["system-status"]
    )
    print(f"Broadcast to {successful} clients")
    
    # Check messages received
    for i, queue in enumerate(connections):
        try:
            message = await asyncio.wait_for(queue.get(), timeout=0.1)
            print(f"Client {i} received: {message.type}")
        except asyncio.TimeoutError:
            print(f"Client {i} timeout")
    
    # Clean up
    for queue in connections:
        await manager.unregister_connection(queue)
    
    print(f"Final connection count: {manager.get_connection_count()}")

await example_multiple_connections()

[SSEBroadcastManager] Connection registered. Total connections: 1
[SSEBroadcastManager] Connection registered. Total connections: 2
[SSEBroadcastManager] Connection registered. Total connections: 3
Registered 3 connections
[SSEBroadcastManager] Broadcast to 3/3 clients
Broadcast to 3 clients
Client 0 received: announcement
Client 1 received: announcement
Client 2 received: announcement
[SSEBroadcastManager] Connection unregistered. Total connections: 2
[SSEBroadcastManager] Connection unregistered. Total connections: 1
[SSEBroadcastManager] Connection unregistered. Total connections: 0
Final connection count: 0


## Testing

In [None]:
# Test: Basic functionality
async def test_basic_functionality():
    manager = SSEBroadcastManager()
    
    # Test connection registration
    queue = await manager.register_connection()
    assert manager.get_connection_count() == 1
    
    # Test broadcasting
    await manager.broadcast("test", {"data": "test"})
    message = await asyncio.wait_for(queue.get(), timeout=1.0)
    assert message.type == "test"
    assert message.data == {"data": "test"}
    
    # Test unregistration
    await manager.unregister_connection(queue)
    assert manager.get_connection_count() == 0
    
    print("✓ Basic functionality tests passed")

await test_basic_functionality()

✓ Basic functionality tests passed


In [None]:
# Test: History functionality
async def test_history():
    manager = SSEBroadcastManager(history_size=3)
    
    # Send multiple messages
    for i in range(5):
        await manager.broadcast("test", {"index": i})
    
    # Check history (should only have last 3)
    history = manager.get_history()
    assert len(history) == 3
    assert history[0].data["index"] == 2
    assert history[-1].data["index"] == 4
    
    # Test history limit
    limited = manager.get_history(limit=2)
    assert len(limited) == 2
    
    print("✓ History tests passed")

await test_history()

✓ History tests passed


In [None]:
#| hide
# Run nbdev's export step so this notebook's `#| export` cells are written to the library module.
import nbdev; nbdev.nbdev_export()