# Broadcast

> Broadcasting infrastructure for keeping multiple browser tabs and clients in sync via Server-Sent Events (SSE)

In [None]:
#| default_exp core.broadcast

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import asyncio
from typing import Dict, Set, Any, Optional, Callable, Deque
from collections import deque
from datetime import datetime
import json
from dataclasses import dataclass, field, asdict
from fasthtml.common import EventStream, sse_message

## Core Classes

In [None]:
#| export
@dataclass
class BroadcastMessage:
    """Standard broadcast message format for SSE communication.

    Attributes:
        type: Application-defined message type (e.g. ``"notification"``).
        data: Message payload; must be JSON-serializable for ``to_json``/``to_sse``.
        timestamp: Creation time as an ISO-8601 string (naive local time),
            filled in automatically.
        metadata: Optional extra information attached by the sender.
    """
    type: str
    data: Dict[str, Any]
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(
        self
    ) -> Dict[str, Any]:  # All fields as a plain dict (nested values copied by `asdict`)
        """Convert message to dictionary format"""
        return asdict(self)

    def to_json(
        self
    ) -> str:  # Single-line JSON encoding of `to_dict()`
        """Convert message to JSON string"""
        return json.dumps(self.to_dict())

    def to_sse(
        self,
        event_type: Optional[str] = None  # SSE event name; falsy values mean "message"
    ) -> str:  # One complete SSE frame, terminated by a blank line
        """Convert the message to a Server-Sent Events frame.

        The payload is always the JSON encoding of the message, so clients
        receive the same format whether or not an event name is given.
        (Previously the default branch passed the raw dict to FastHTML's
        ``sse_message``, which renders FT/HTML elements, not JSON.)
        """
        event_name = event_type or "message"
        # json.dumps (without indent) escapes embedded newlines, so the whole
        # payload fits on a single `data:` line as SSE framing requires.
        return f"event: {event_name}\ndata: {self.to_json()}\n\n"

In [None]:
#| export
class BroadcastManager:
    """Manages SSE broadcast connections across multiple tabs/clients.

    Each registered connection owns a bounded ``asyncio.Queue`` of
    ``BroadcastMessage`` objects. ``broadcast`` and ``send_to`` enqueue into
    those queues, and a consumer (e.g. an SSE stream) drains them. A
    connection whose queue cannot accept a message within ``queue_timeout``
    seconds is dropped from the manager.
    """

    def __init__(self,
                 max_queue_size: int = 100,  # Max pending messages per connection queue (<= 0 means unbounded, per asyncio.Queue)
                 history_limit: int = 50,  # Number of most recent messages kept in `history`
                 queue_timeout: float = 0.1,  # Seconds to wait for queue space before dropping a connection
                 debug: bool = False):  # Print debug messages to stdout
        """
        Initialize the broadcast manager.

        Args:
            max_queue_size: Maximum size for each connection's message queue
            history_limit: Number of recent messages to keep in history
            queue_timeout: Timeout for queue operations in seconds
            debug: Enable debug logging
        """
        self.connections: Dict[str, asyncio.Queue] = {}
        self.connection_metadata: Dict[str, Dict[str, Any]] = {}
        # Bounded deque: the oldest message is discarded once history_limit is reached.
        self.history: Deque[BroadcastMessage] = deque(maxlen=history_limit)
        # Guards `connections` / `connection_metadata` mutation.
        self.lock = asyncio.Lock()
        self.max_queue_size = max_queue_size
        self.queue_timeout = queue_timeout
        self.debug = debug
        self._connection_counter = 0

    async def register(self,
                      connection_id: Optional[str] = None,  # ID to use; "conn_<n>" is generated when None
                      metadata: Optional[Dict[str, Any]] = None) -> tuple[str, asyncio.Queue]:
        """
        Register a new connection and return its queue.

        Registering an ID that is already present replaces that connection's
        queue and metadata; any later ``unregister(connection_id)`` (e.g. from
        the earlier consumer's cleanup) then removes the new registration.

        Args:
            connection_id: Optional ID for the connection (auto-generated if not provided)
            metadata: Optional metadata for the connection. It is copied, so
                the caller's dict is not modified.

        Returns:
            Tuple of (connection_id, queue)
        """
        queue = asyncio.Queue(maxsize=self.max_queue_size)

        async with self.lock:
            if connection_id is None:
                self._connection_counter += 1
                connection_id = f"conn_{self._connection_counter}"

            # Copy before adding 'connected_at' so the caller's dict is untouched.
            conn_meta = dict(metadata or {})
            conn_meta['connected_at'] = datetime.now().isoformat()

            self.connections[connection_id] = queue
            self.connection_metadata[connection_id] = conn_meta

            if self.debug:
                print(f"[BroadcastManager] Registered connection: {connection_id}")

        return connection_id, queue

    async def unregister(
        self,
        connection_id: str  # ID previously returned by `register`
    ):
        """Unregister a connection; a no-op if the ID is not registered."""
        async with self.lock:
            if connection_id in self.connections:
                del self.connections[connection_id]
                del self.connection_metadata[connection_id]

                if self.debug:
                    print(f"[BroadcastManager] Unregistered connection: {connection_id}")

    async def broadcast(self,
                       message_type: str,  # Value stored in BroadcastMessage.type
                       data: Dict[str, Any],
                       metadata: Optional[Dict[str, Any]] = None,
                       exclude: Optional[Set[str]] = None) -> int:
        """
        Broadcast a message to all connected clients.

        The message is always appended to ``history``, even if no client
        receives it. The lock is held while awaiting each queue, so concurrent
        broadcasts are serialized. Each stalled connection can delay
        register/unregister by up to ``queue_timeout``.

        Args:
            message_type: Type of the message
            data: Message data
            metadata: Optional metadata
            exclude: Set of connection IDs to exclude from broadcast

        Returns:
            Number of successful broadcasts
        """
        message = BroadcastMessage(type=message_type, data=data, metadata=metadata)
        self.history.append(message)

        exclude = exclude or set()
        successful = 0

        async with self.lock:
            disconnected = set()

            for conn_id, queue in self.connections.items():
                if conn_id in exclude:
                    continue

                try:
                    await asyncio.wait_for(
                        queue.put(message),
                        timeout=self.queue_timeout
                    )
                    successful += 1
                # Queue.put() waits rather than raising QueueFull, so a full
                # queue normally surfaces here as TimeoutError.
                except (asyncio.TimeoutError, asyncio.QueueFull):
                    disconnected.add(conn_id)
                    if self.debug:
                        print(f"[BroadcastManager] Failed to send to {conn_id}")

            # Drop connections that could not keep up. NOTE(review): whoever is
            # consuming the dropped queue is not notified here and will stop
            # receiving broadcasts.
            for conn_id in disconnected:
                del self.connections[conn_id]
                del self.connection_metadata[conn_id]

        if self.debug:
            print(f"[BroadcastManager] Broadcast to {successful} clients")

        return successful

    async def send_to(self,
                     connection_id: str,  # Target connection ID
                     message_type: str,  # Value stored in BroadcastMessage.type
                     data: Dict[str, Any],
                     metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Send a message to a specific connection.

        The message is not added to ``history``. If the connection's queue
        does not accept it within ``queue_timeout``, the connection is
        unregistered.

        Args:
            connection_id: Target connection ID
            message_type: Type of the message
            data: Message data
            metadata: Optional metadata

        Returns:
            True if successful, False otherwise
        """
        if connection_id not in self.connections:
            return False

        message = BroadcastMessage(type=message_type, data=data, metadata=metadata)

        try:
            await asyncio.wait_for(
                self.connections[connection_id].put(message),
                timeout=self.queue_timeout
            )
            return True
        except (asyncio.TimeoutError, asyncio.QueueFull):
            await self.unregister(connection_id)
            return False

    def get_connection_count(
        self
    ) -> int:  # Number of currently registered connections
        """Get the number of active connections."""
        return len(self.connections)

    def get_history(
        self,
        limit: Optional[int] = None  # Max number of most recent messages; None returns all, <= 0 returns none
    ) -> list[BroadcastMessage]:  # Messages in oldest-to-newest order
        """Get broadcast history."""
        items = list(self.history)
        if limit is None:
            return items
        # Previously `if limit:` made limit=0 return everything and a negative
        # limit drop messages from the front; treat both as "no messages".
        if limit <= 0:
            return []
        return items[-limit:]

## Helper Functions

In [None]:
#| export
async def create_broadcast_endpoint(manager: BroadcastManager,
                                   connection_id: Optional[str] = None,  # Optional connection ID
                                   heartbeat_interval: float = 30.0,  # Interval for sending heartbeat messages
                                   send_history: bool = False,  # Whether to send recent history on connection
                                   history_limit: int = 10) -> EventStream:  # Max history messages replayed when `send_history`
    """Create an SSE endpoint for broadcasting.

    The returned stream registers with `manager` when iteration starts. It then
    emits a connection comment, optionally replays history as `history` events,
    and relays queued messages. A heartbeat comment is sent whenever
    `heartbeat_interval` seconds pass without a message. The connection is
    unregistered when the stream ends for any reason.
    """
    async def stream():
        "Register with `manager`, yield SSE frames until closed, then unregister."
        conn_id, queue = await manager.register(connection_id)

        try:
            # Send connection confirmation (SSE comment line, ignored by EventSource)
            yield f": Connected as {conn_id} (active: {manager.get_connection_count()})\n\n"

            # Send history if requested
            if send_history:
                history = manager.get_history(history_limit)
                for msg in history:
                    yield f"event: history\ndata: {msg.to_json()}\n\n"

            # Main message loop
            while True:
                try:
                    # Wait for a message; time out so a heartbeat can be sent
                    message = await asyncio.wait_for(
                        queue.get(),
                        timeout=heartbeat_interval
                    )
                except asyncio.TimeoutError:
                    yield f": heartbeat {datetime.now().isoformat()}\n\n"
                    continue

                yield message.to_sse()
        finally:
            # Runs on normal completion, generator close and cancellation
            # (e.g. client disconnect). CancelledError is intentionally not
            # caught, so the cancellation still reaches whoever requested it.
            await manager.unregister(conn_id)

    return EventStream(stream())

In [None]:
#| export
def create_broadcast_handler(manager: BroadcastManager,
                            element_builder: Optional[Callable] = None,  # Optional (message_type, data) -> element renderer
                            heartbeat_interval: float = 30.0):  # Seconds without a message before a heartbeat is sent
    """Create a broadcast handler function that can be used with FastHTML routes.

    The handler returns an `EventStream` that relays messages from `manager`.
    When `element_builder` is given, each message is rendered with it and sent
    via `sse_message`. A message for which the builder returns `None` is
    skipped instead of being sent as an empty event. Without a builder, the
    raw message is sent via `BroadcastMessage.to_sse()`.
    """
    async def handler(
        connection_id: Optional[str] = None  # Passed to `manager.register`; auto-generated when None
    ):
        """Return an `EventStream` relaying `manager` broadcasts to one client.

        NOTE(review): when mounted as a route, `connection_id` presumably comes
        from the request (client-controlled). Reusing an already-registered ID
        replaces that connection's queue in the manager.
        """
        async def stream():
            "Register with `manager`, relay queued messages as SSE frames, and unregister on exit."
            conn_id, queue = await manager.register(connection_id)

            try:
                yield f": Connected ({manager.get_connection_count()} active)\n\n"

                while True:
                    try:
                        message = await asyncio.wait_for(queue.get(), timeout=heartbeat_interval)
                    except asyncio.TimeoutError:
                        yield ": heartbeat\n\n"
                        continue

                    if not element_builder:
                        # Default to sending the raw message
                        yield message.to_sse()
                        continue

                    elements = element_builder(message.type, message.data)
                    if elements is None:
                        # Builder has nothing to render for this type; emitting
                        # sse_message(None) would send an empty event that an
                        # sse-swap target would swap in, clearing its content.
                        continue
                    yield sse_message(elements)
            finally:
                await manager.unregister(conn_id)

        return EventStream(stream())

    return handler

In [None]:
#| export 
def setup_broadcast_routes(app,
                          manager: BroadcastManager,  # The broadcast manager instance
                          prefix: str = "/sse",  # URL prefix for SSE endpoints
                          element_builder: Optional[Callable] = None):  # Optional (message_type, data) -> element renderer
    """Setup broadcast routes on a FastHTML app.

    Two routes are registered, in this order:
      - ``{prefix}/broadcast``: opens an SSE stream built by `create_broadcast_handler`.
      - ``{prefix}/status``: returns a dict with the active connection count
        and the number of messages held in `manager.history`.
    """
    stream_handler = create_broadcast_handler(manager, element_builder)
    broadcast_path = f"{prefix}/broadcast"
    status_path = f"{prefix}/status"

    async def broadcast_endpoint(
        connection_id: Optional[str] = None  # Forwarded to the handler; the manager generates one when None
    ):
        "Open an SSE broadcast stream for the requesting client."
        return await stream_handler(connection_id)

    def broadcast_status():
        "Report live connection count and stored history size."
        active = manager.get_connection_count()
        stored = len(manager.history)
        return {"connections": active, "history_size": stored}

    # Equivalent to decorating each function with @app.route(path).
    app.route(broadcast_path)(broadcast_endpoint)
    app.route(status_path)(broadcast_status)

## Examples

### Basic Usage

In [None]:
# Example: Basic broadcast setup
from fasthtml.common import *

# Create a manager with debug printing enabled
broadcast_mgr = BroadcastManager(debug=True)

# Example of broadcasting an update to every connected client
async def notify_all_clients():
    "Send a high-priority 'notification' broadcast through `broadcast_mgr`."
    payload = {"message": "System update completed"}
    extra = {"priority": "high"}
    await broadcast_mgr.broadcast(
        message_type="notification",
        data=payload,
        metadata=extra,
    )

### Custom Element Builder

In [None]:
# Example: Custom element builder for OOB swaps
def build_oob_elements(message_type: str, data: Dict[str, Any]):
    """Map a broadcast message to htmx out-of-band (OOB) swap elements.

    Returns a `Div` wrapping one OOB element for the known message types
    (``status_update`` and ``counter_update``), or ``None`` for any other type.
    """
    from fasthtml.common import Div, Span

    if message_type == "status_update":
        # hx-swap-oob="true" replaces the whole #status-display element
        oob = Div(
            data.get("status", "Unknown"),
            id="status-display",
            hx_swap_oob="true",
        )
    elif message_type == "counter_update":
        # innerHTML swaps only the contents of #counter
        oob = Span(
            str(data.get("count", 0)),
            id="counter",
            hx_swap_oob="innerHTML",
        )
    else:
        return None

    return Div(oob)

### Integration with FastHTML App

In [None]:
# Example: Complete integration
def create_app_with_broadcast():
    """Build a demo FastHTML app wired to a fresh `BroadcastManager`.

    Routes: ``/`` (demo page), ``/trigger`` (broadcasts a counter update),
    plus the ``/sse/broadcast`` and ``/sse/status`` routes from
    `setup_broadcast_routes`.

    Returns:
        Tuple of (app, manager).
    """
    # The htmx SSE extension ships separately from htmx core. Without this
    # script the hx_ext="sse" / sse_connect attributes below do nothing and
    # the page never opens the event stream.
    sse_ext = Script(src="https://unpkg.com/htmx-ext-sse@2.2.1/sse.js")
    app, rt = fast_app(hdrs=(sse_ext,))
    manager = BroadcastManager()

    # Setup routes
    setup_broadcast_routes(app, manager, element_builder=build_oob_elements)

    @rt("/")
    def index():
        return Div(
            H1("Broadcast Demo"),
            Div(
                id="global-sse",
                hx_ext="sse",
                sse_connect="/sse/broadcast",
                sse_swap="message"
            ),
            Div("Waiting for updates...", id="status-display"),
            Span("0", id="counter")
        )

    @rt("/trigger")
    async def trigger():
        await manager.broadcast(
            "counter_update",
            {"count": 42}
        )
        return "Broadcast sent!"

    return app, manager

## Export

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()