# Broadcast

> Broadcasting infrastructure for cross-tab synchronization over Server-Sent Events (SSE)

In [None]:
#| default_exp core.broadcast

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import asyncio
from typing import Dict, Set, Any, Optional, Callable, Deque
from collections import deque
from datetime import datetime
import json
from dataclasses import dataclass, field, asdict
from fasthtml.common import EventStream, sse_message, FT

## Core Classes

In [None]:
#| export
@dataclass
class BroadcastMessage:
    """Standard broadcast message format for SSE communication"""
    type: str  # Application-defined message type
    data: Dict[str, Any]  # Message payload
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601 local time at creation
    metadata: Optional[Dict[str, Any]] = None  # Optional extra information attached to the message

    def to_dict(
        self
    ) -> Dict[str, Any]:  # Dictionary representation of the message
        """Convert message to dictionary format"""
        return asdict(self)

    def to_json(
        self
    ) -> str:  # JSON string representation of the message
        """Convert message to JSON string"""
        return json.dumps(self.to_dict())

    def to_sse(
        self,
        event_type: Optional[str] = 'message'  # SSE event type for the message (None means the default 'message')
    ) -> str:  # SSE formatted message string
        """Convert to SSE message format using FastHTML's sse_message.

        When `data` is an `FT` element it is rendered by `sse_message`; otherwise the
        whole message is sent as a single JSON `data:` line.
        """
        # A stream with no `event:` field is dispatched as "message", so None maps to
        # that instead of being formatted literally as "event: None".
        if event_type is None:
            event_type = 'message'

        if isinstance(self.data, FT):
            return sse_message(self.data, event=event_type)
        return f"event: {event_type}\ndata: {self.to_json()}\n\n"

### Tests for BroadcastMessage

In [None]:
#| test
# Test BroadcastMessage creation and methods
msg = BroadcastMessage(type="test", data={"content": "Hello World", "count": 42})

# Required fields are stored as given; optional ones get their defaults
assert (msg.type, msg.data["content"], msg.data["count"]) == ("test", "Hello World", 42)
assert msg.timestamp is not None
assert msg.metadata is None

# Metadata passed at construction is kept on the message
msg_with_meta = BroadcastMessage(
    type="notification",
    data={"message": "Test"},
    metadata={"sender": "system", "priority": "high"},
)
for key, expected in (("sender", "system"), ("priority", "high")):
    assert msg_with_meta.metadata[key] == expected

print("✓ BroadcastMessage creation tests passed")

✓ BroadcastMessage creation tests passed


In [None]:
#| test
# Test BroadcastMessage conversion methods
import json

msg = BroadcastMessage(type="update", data={"status": "active", "value": 100})

# Dictionary form exposes the fields by name
msg_dict = msg.to_dict()
assert isinstance(msg_dict, dict)
assert (msg_dict["type"], msg_dict["data"]["status"]) == ("update", "active")
assert "timestamp" in msg_dict

# JSON form parses back to the same values
msg_json = msg.to_json()
assert isinstance(msg_json, str)
parsed = json.loads(msg_json)
assert (parsed["type"], parsed["data"]["value"]) == ("update", 100)

# SSE form with the default event type
msg_sse = msg.to_sse()
assert isinstance(msg_sse, str)
for fragment in ("event: message", "data:"):
    assert fragment in msg_sse

# SSE form with a custom event type
msg_sse_custom = msg.to_sse(event_type="custom_event")
assert "event: custom_event" in msg_sse_custom

print("✓ BroadcastMessage conversion methods tests passed")

✓ BroadcastMessage conversion methods tests passed


In [None]:
#| export
class BroadcastManager:
    """Manages SSE broadcast connections across multiple tabs/clients.

    Every registered connection owns a bounded `asyncio.Queue`. `broadcast` and
    `send_to` put `BroadcastMessage` objects on those queues; a connection whose
    queue cannot accept a message within `queue_timeout` seconds is dropped.
    """

    def __init__(self,
                 max_queue_size: int = 100,  # Maximum size for each connection's message queue
                 history_limit: int = 50,  # Number of recent messages to keep in history
                 queue_timeout: float = 0.1,  # Timeout for queue operations in seconds
                 debug: bool = False # Enable debug logging
                ):
        """
        Initialize the broadcast manager.
        """
        self.connections: Dict[str, asyncio.Queue] = {}
        self.connection_metadata: Dict[str, Dict[str, Any]] = {}
        # Bounded deque: the oldest message is discarded once `history_limit` is reached
        self.history: Deque[BroadcastMessage] = deque(maxlen=history_limit)
        self.lock = asyncio.Lock()
        self.max_queue_size = max_queue_size
        self.queue_timeout = queue_timeout
        self.debug = debug
        self._connection_counter = 0

    async def register(self,
                      connection_id: Optional[str] = None,  # Optional ID for the connection (auto-generated if not provided)
                      metadata: Optional[Dict[str, Any]] = None # Optional metadata for the connection
                      ) -> tuple[str, asyncio.Queue]: # Tuple of (connection_id, queue)
        """
        Register a new connection and return its queue.

        Registering an explicit `connection_id` that is already in use replaces the
        existing queue. The stored metadata is a copy of `metadata` with a
        `connected_at` ISO timestamp added.
        """
        queue = asyncio.Queue(maxsize=self.max_queue_size)

        async with self.lock:
            if connection_id is None:
                # Skip ids a caller registered explicitly (e.g. "conn_1") so an
                # auto-generated id never silently replaces a live connection.
                while True:
                    self._connection_counter += 1
                    connection_id = f"conn_{self._connection_counter}"
                    if connection_id not in self.connections:
                        break

            self.connections[connection_id] = queue
            # Copy before stamping so the caller's dict is left untouched
            info = dict(metadata) if metadata else {}
            info['connected_at'] = datetime.now().isoformat()
            self.connection_metadata[connection_id] = info

            if self.debug:
                print(f"[BroadcastManager] Registered connection: {connection_id}")

        return connection_id, queue

    async def unregister(
        self,
        connection_id: str  # ID of the connection to unregister
    ):
        """Unregister a connection. Unknown IDs are ignored."""
        async with self.lock:
            if connection_id in self.connections:
                del self.connections[connection_id]
                # pop() tolerates a metadata entry that is already gone
                self.connection_metadata.pop(connection_id, None)

                if self.debug:
                    print(f"[BroadcastManager] Unregistered connection: {connection_id}")

    async def broadcast(self,
                       message_type: str,  # Type of the message
                       data: Dict[str, Any], # Message data
                       metadata: Optional[Dict[str, Any]] = None, # Optional metadata
                       exclude: Optional[Set[str]] = None # Set of connection IDs to exclude from broadcast
                       ) -> int: # Number of successful broadcasts
        """
        Broadcast a message to all connected clients.

        The message is appended to the history even when nobody receives it.
        Connections that cannot accept the message within `queue_timeout` are removed.
        """
        message = BroadcastMessage(type=message_type, data=data, metadata=metadata)
        self.history.append(message)

        exclude = exclude or set()
        successful = 0

        # The lock is held for the whole fan-out, so register/unregister and other
        # broadcasts wait until this message has been offered to every queue.
        async with self.lock:
            disconnected = set()

            for conn_id, queue in self.connections.items():
                if conn_id in exclude:
                    continue

                try:
                    await asyncio.wait_for(
                        queue.put(message),
                        timeout=self.queue_timeout
                    )
                    successful += 1
                except (asyncio.TimeoutError, asyncio.QueueFull):
                    # Removal is deferred: the dict must not change while it is iterated
                    disconnected.add(conn_id)
                    if self.debug:
                        print(f"[BroadcastManager] Failed to send to {conn_id}")

            # Clean up disconnected clients
            for conn_id in disconnected:
                self.connections.pop(conn_id, None)
                self.connection_metadata.pop(conn_id, None)

        if self.debug:
            print(f"[BroadcastManager] Broadcast to {successful} clients")

        return successful

    async def send_to(self,
                     connection_id: str,  # Target connection ID
                     message_type: str,  # Type of the message
                     data: Dict[str, Any], # Message data
                     metadata: Optional[Dict[str, Any]] = None # Optional metadata
                     ) -> bool: # True if successful, False otherwise
        """
        Send a message to a specific connection.

        Direct messages are not recorded in the history. A connection that cannot
        accept the message within `queue_timeout` is unregistered.
        """
        queue = self.connections.get(connection_id)
        if queue is None:
            return False

        message = BroadcastMessage(type=message_type, data=data, metadata=metadata)

        try:
            await asyncio.wait_for(queue.put(message), timeout=self.queue_timeout)
            return True
        except (asyncio.TimeoutError, asyncio.QueueFull):
            await self.unregister(connection_id)
            return False

    def get_connection_count(
        self
    ) -> int:  # Number of active connections
        """Get the number of active connections."""
        return len(self.connections)

    def get_history(
        self,
        limit: Optional[int] = None  # Maximum number of messages to return (None returns everything)
    ) -> "list[BroadcastMessage]":  # List of broadcast messages from history, oldest first
        """Get broadcast history.

        With `limit`, only the most recent `limit` messages are returned; a limit of
        zero or less yields an empty list.
        """
        messages = list(self.history)
        if limit is None:
            return messages
        if limit <= 0:
            # `messages[-0:]` would be the whole list, so handle this explicitly
            return []
        return messages[-limit:]

### Tests for BroadcastManager

In [None]:
#| test
# Test BroadcastManager initialization
manager = BroadcastManager(max_queue_size=50, history_limit=20, queue_timeout=0.5, debug=False)

# Constructor arguments are stored unchanged
assert (manager.max_queue_size, manager.queue_timeout) == (50, 0.5)
assert manager.debug == False

# A fresh manager has no connections and no history
for container in (manager.connections, manager.history):
    assert len(container) == 0
assert manager.get_connection_count() == 0

print("✓ BroadcastManager initialization tests passed")

✓ BroadcastManager initialization tests passed


In [None]:
#| test
# Test async operations with BroadcastManager
import asyncio

async def test_registration():
    manager = BroadcastManager()
    
    # Test registration without ID
    conn_id1, queue1 = await manager.register()
    assert conn_id1 == "conn_1"
    assert isinstance(queue1, asyncio.Queue)
    assert manager.get_connection_count() == 1
    
    # Test registration with custom ID
    conn_id2, queue2 = await manager.register(connection_id="custom_id")
    assert conn_id2 == "custom_id"
    assert manager.get_connection_count() == 2
    
    # Test registration with metadata
    conn_id3, queue3 = await manager.register(
        connection_id="meta_conn",
        metadata={"user": "test_user", "role": "admin"}
    )
    assert manager.connection_metadata[conn_id3]["user"] == "test_user"
    assert manager.connection_metadata[conn_id3]["role"] == "admin"
    assert "connected_at" in manager.connection_metadata[conn_id3]
    
    # Test unregistration
    await manager.unregister(conn_id1)
    assert manager.get_connection_count() == 2
    assert conn_id1 not in manager.connections
    
    await manager.unregister(conn_id2)
    await manager.unregister(conn_id3)
    assert manager.get_connection_count() == 0
    
    print("✓ BroadcastManager registration tests passed")

# Run the async test
await test_registration()

✓ BroadcastManager registration tests passed


In [None]:
#| test
# Test broadcasting functionality
async def test_broadcasting():
    manager = BroadcastManager(history_limit=5)
    
    # Register multiple connections
    conn_id1, queue1 = await manager.register("conn1")
    conn_id2, queue2 = await manager.register("conn2")
    conn_id3, queue3 = await manager.register("conn3")
    
    # Test broadcast to all
    count = await manager.broadcast(
        message_type="notification",
        data={"text": "Hello all!"},
        metadata={"broadcast": True}
    )
    assert count == 3
    
    # Check that all queues received the message
    msg1 = await asyncio.wait_for(queue1.get(), timeout=1.0)
    msg2 = await asyncio.wait_for(queue2.get(), timeout=1.0)
    msg3 = await asyncio.wait_for(queue3.get(), timeout=1.0)
    
    assert msg1.type == "notification"
    assert msg1.data["text"] == "Hello all!"
    assert msg2.type == "notification"
    assert msg3.type == "notification"
    
    # Test broadcast with exclusion
    count = await manager.broadcast(
        message_type="update",
        data={"status": "active"},
        exclude={"conn2"}
    )
    assert count == 2
    
    msg1 = await asyncio.wait_for(queue1.get(), timeout=1.0)
    msg3 = await asyncio.wait_for(queue3.get(), timeout=1.0)
    assert msg1.type == "update"
    assert msg3.type == "update"
    
    # Queue2 should be empty
    assert queue2.empty()
    
    # Check history
    history = manager.get_history()
    assert len(history) == 2
    assert history[0].type == "notification"
    assert history[1].type == "update"
    
    # Clean up
    await manager.unregister("conn1")
    await manager.unregister("conn2")
    await manager.unregister("conn3")
    
    print("✓ BroadcastManager broadcasting tests passed")

await test_broadcasting()

✓ BroadcastManager broadcasting tests passed


In [None]:
#| test
# Test send_to functionality
async def test_send_to():
    manager = BroadcastManager()
    
    # Register connections
    conn_id1, queue1 = await manager.register("target_conn")
    conn_id2, queue2 = await manager.register("other_conn")
    
    # Send to specific connection
    success = await manager.send_to(
        connection_id="target_conn",
        message_type="direct_message",
        data={"content": "Hello target!"},
        metadata={"private": True}
    )
    assert success == True
    
    # Check that only target connection received the message
    msg1 = await asyncio.wait_for(queue1.get(), timeout=1.0)
    assert msg1.type == "direct_message"
    assert msg1.data["content"] == "Hello target!"
    assert msg1.metadata["private"] == True
    
    # Other connection should not receive it
    assert queue2.empty()
    
    # Test sending to non-existent connection
    success = await manager.send_to(
        connection_id="non_existent",
        message_type="test",
        data={}
    )
    assert success == False
    
    # Clean up
    await manager.unregister("target_conn")
    await manager.unregister("other_conn")
    
    print("✓ BroadcastManager send_to tests passed")

await test_send_to()

✓ BroadcastManager send_to tests passed


In [None]:
#| test
# Test history functionality
async def test_history():
    manager = BroadcastManager(history_limit=3)
    
    # Register a connection
    conn_id, queue = await manager.register()
    
    # Send multiple broadcasts
    for i in range(5):
        await manager.broadcast(
            message_type=f"msg_{i}",
            data={"index": i}
        )
    
    # Check history limit is respected
    history = manager.get_history()
    assert len(history) == 3  # Only last 3 messages
    assert history[0].type == "msg_2"
    assert history[1].type == "msg_3"
    assert history[2].type == "msg_4"
    
    # Test get_history with limit
    limited_history = manager.get_history(limit=2)
    assert len(limited_history) == 2
    assert limited_history[0].type == "msg_3"
    assert limited_history[1].type == "msg_4"
    
    # Clean up
    await manager.unregister(conn_id)
    
    print("✓ BroadcastManager history tests passed")

await test_history()

✓ BroadcastManager history tests passed


## Helper Functions

In [None]:
#| export
def create_broadcast_handler(manager: BroadcastManager,  # The broadcast manager that owns the connections
                            element_builder: Optional[Callable] = None,  # Optional callable taking (message_type, data) and returning what to send
                            heartbeat_interval: float = 30.0  # Seconds to wait for a message before sending a heartbeat comment
                            ):
    """Create a broadcast handler function that can be used with FastHTML routes."""
    async def handler(
        connection_id: Optional[str] = None  # Optional connection ID
    ):
        """Handle SSE broadcast connection."""
        async def stream():
            """Generate SSE stream for the connection."""
            conn_id, queue = await manager.register(connection_id)

            try:
                # Lines starting with ':' are SSE comments, which clients ignore
                yield f": Connected ({manager.get_connection_count()} active)\n\n"

                while True:
                    # Only the queue wait is guarded, so a TimeoutError raised by
                    # element_builder is not mistaken for an idle connection.
                    try:
                        message = await asyncio.wait_for(queue.get(), timeout=heartbeat_interval)
                    except asyncio.TimeoutError:
                        yield ": heartbeat\n\n"
                        continue

                    if element_builder:
                        # Use custom element builder if provided
                        elements = element_builder(message.type, message.data)
                        yield sse_message(elements)
                    else:
                        # Default to sending raw message
                        yield message.to_sse()

            finally:
                # Runs when the stream is closed or cancelled, so the connection never lingers
                await manager.unregister(conn_id)

        return EventStream(stream())

    return handler

In [None]:
#| test
# Test create_broadcast_handler
def test_create_broadcast_handler():
    """Check that a handler is produced with and without an element builder."""
    mgr = BroadcastManager()

    def custom_builder(msg_type, data):
        return f"<div>Type: {msg_type}, Data: {data}</div>"

    # Default behaviour: raw messages
    plain_handler = create_broadcast_handler(mgr)
    assert callable(plain_handler)

    # Custom element builder supplied by keyword
    built_handler = create_broadcast_handler(mgr, element_builder=custom_builder)
    assert callable(built_handler)

    print("✓ create_broadcast_handler tests passed")

test_create_broadcast_handler()

✓ create_broadcast_handler tests passed


In [None]:
#| export 
def setup_broadcast_routes(app,  # The FastHTML app to add the routes to
                          manager: BroadcastManager,  # The broadcast manager instance
                          prefix: str = "/sse",  # URL prefix for SSE endpoints
                          element_builder: Optional[Callable] = None  # Optional element builder passed on to the handler
                          ):
    """Setup broadcast routes on a FastHTML app.

    Registers `{prefix}/broadcast` (the SSE stream) and `{prefix}/status`
    (connection count and history size).
    """
    handler = create_broadcast_handler(manager, element_builder)

    # Drop trailing slashes so a prefix such as "/sse/" does not yield "/sse//broadcast"
    prefix = prefix.rstrip("/")

    # Register the route
    @app.route(f"{prefix}/broadcast")
    async def broadcast_endpoint(
        connection_id: Optional[str] = None  # Optional connection ID for the SSE connection
    ):
        """SSE broadcast endpoint for client connections."""
        return await handler(connection_id)

    # Add a status endpoint
    @app.route(f"{prefix}/status")
    def broadcast_status():
        """Report the number of active connections and stored history entries."""
        return {
            "connections": manager.get_connection_count(),
            "history_size": len(manager.history)
        }

In [None]:
#| test
# Test setup_broadcast_routes
from fasthtml.common import FastHTML

def test_setup_broadcast_routes():
    """Check that both endpoints are registered under the requested prefix."""
    def registered_paths(app):
        # `routes` is the app's list of route objects; entries without a `path` are ignored
        return {getattr(route, "path", None) for route in app.routes}

    manager = BroadcastManager()

    # Setup routes with default prefix
    app = FastHTML()
    setup_broadcast_routes(app, manager)
    default_paths = registered_paths(app)
    assert "/sse/broadcast" in default_paths
    assert "/sse/status" in default_paths

    # Setup routes with custom prefix and element builder
    def custom_builder(msg_type, data):
        return f"Custom: {msg_type}"

    app2 = FastHTML()
    setup_broadcast_routes(
        app2,
        manager,
        prefix="/custom_sse",
        element_builder=custom_builder
    )
    custom_paths = registered_paths(app2)
    assert "/custom_sse/broadcast" in custom_paths
    assert "/custom_sse/status" in custom_paths
    # The default prefix must not leak into an app configured with a custom one
    assert "/sse/broadcast" not in custom_paths

    print("✓ setup_broadcast_routes tests passed")

test_setup_broadcast_routes()

✓ setup_broadcast_routes tests passed


## Export

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()