# Multi Stream

> Manage multiple related SSE streams for a single entity

In [None]:
#| default_exp core.multi_stream

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import asyncio
from dataclasses import dataclass, field
from typing import Dict, Optional, Any, Callable, List, AsyncGenerator, Union
from urllib.parse import quote

from fasthtml.common import EventStream, sse_message, Div, Span, FT

from cjm_fasthtml_sse.core.connections import ConnectionRegistry, SSEConnection
from cjm_fasthtml_sse.core.streaming import SSEStream, StreamConfig, OOBStreamBuilder

## Core Classes

In [None]:
#| export
@dataclass
class StreamEndpoint:
    """Describes one named SSE stream that a `MultiStreamManager` can serve.

    Only `name`, `path_suffix` and `data_source` are required; the other
    fields have defaults.
    """
    name: str  # Lookup key for the stream (e.g., 'progress', 'status')
    path_suffix: str  # Appended to the manager's base path to build the URL (e.g., '_progress')
    data_source: Callable  # Called with the entity ID; returns the data source (or a zero-arg callable producing one)
    transform_fn: Optional[Callable] = field(default=None)  # Turns each streamed item (and the final entity) into message content
    config: Optional[StreamConfig] = field(default=None)  # Per-stream config; when unset the manager's default is used
    metadata: Dict[str, Any] = field(default_factory=dict)  # Free-form extra data (not read by MultiStreamManager)

In [None]:
#| export
class MultiStreamManager:
    """Manages multiple related SSE streams for entities.

    Each registered `StreamEndpoint` is served at `base_path + path_suffix`,
    and clients select the entity through an `id` query parameter.
    """

    def __init__(
        self,
        base_path: str = "/stream",  # Base path for stream endpoints
        connection_registry: Optional[ConnectionRegistry] = None,  # Optional shared registry
        default_config: Optional[StreamConfig] = None,  # Default config for all streams
        debug: bool = False  # Enable debug logging
    ):
        """Initialize the multi-stream manager."""
        self.base_path = base_path
        # Compare against None instead of relying on truthiness, so a shared
        # registry (or config) that happens to evaluate falsy is never replaced.
        self.connection_registry = (
            connection_registry if connection_registry is not None
            else ConnectionRegistry(debug=debug)
        )
        self.default_config = default_config if default_config is not None else StreamConfig()
        self.debug = debug
        self.endpoints: Dict[str, StreamEndpoint] = {}  # Registered endpoints keyed by name

    def _get_endpoint(
        self,
        stream_name: str  # Registered stream name
    ) -> StreamEndpoint:  # The matching endpoint
        """Return the endpoint registered as `stream_name`, or raise `ValueError`."""
        endpoint = self.endpoints.get(stream_name)
        if endpoint is None:
            raise ValueError(f"Unknown stream: {stream_name}")
        return endpoint

    def register_endpoint(
        self,
        endpoint: StreamEndpoint  # Stream endpoint configuration
    ) -> 'MultiStreamManager':  # Self for chaining
        """Register a stream endpoint (re-registering a name replaces the earlier one)."""
        self.endpoints[endpoint.name] = endpoint
        if self.debug:
            print(f"[MultiStreamManager] Registered endpoint: {endpoint.name}")
        return self

    def create_element(
        self,
        entity_id: str,  # Entity ID (e.g., job_id, user_id)
        stream_name: str,  # Stream name to connect to
        element_id: Optional[str] = None,  # Optional element ID
        wrapper: Optional[Callable] = None,  # Optional wrapper function for the element
        **attrs  # Additional attributes
    ) -> FT:  # HTMX-enabled SSE element
        """Create an SSE-enabled element for a specific stream.

        The element connects to `base_path + path_suffix?id=<entity_id>`, with
        the id percent-encoded. Caller-supplied `attrs` override the defaults.
        Raises `ValueError` if `stream_name` is not registered.
        """
        endpoint = self._get_endpoint(stream_name)

        # Percent-encode the id so characters such as '&', '#', '+' or '/'
        # reach the handler unchanged instead of splitting/truncating the query.
        encoded_id = quote(str(entity_id), safe='')
        path = f"{self.base_path}{endpoint.path_suffix}?id={encoded_id}"

        sse_attrs = {
            'hx_ext': 'sse',
            'sse_connect': path,
            'sse_swap': 'message'
        }
        if element_id:
            sse_attrs['id'] = element_id
        sse_attrs.update(attrs)

        return Span(**sse_attrs) if wrapper is None else wrapper(**sse_attrs)

    def create_handler(
        self,
        stream_name: str,  # Stream name
        get_entity_fn: Callable,  # Function to get entity state
        is_active_fn: Callable  # Function to check if entity is active
    ) -> Callable:  # Route handler function
        """Create a route handler for a specific stream.

        The returned coroutine reads the entity id from the `id` query parameter
        and returns an `EventStream`. The connection is registered for the
        lifetime of the stream and always removed afterwards. Raises
        `ValueError` if `stream_name` is not registered.
        """
        endpoint = self._get_endpoint(stream_name)

        async def handler(id: str):  # Entity ID from query parameter
            """SSE handler for the stream."""
            async def stream():
                connection = await self.connection_registry.add_connection(
                    conn_type=f"{stream_name}_{id}",
                    metadata={"entity_id": id, "stream": stream_name}
                )
                try:
                    entity = get_entity_fn(id)
                    if not entity or not is_active_fn(entity):
                        # Missing or inactive entity: send one final message (only
                        # when a transform is configured) and end the stream.
                        # NOTE(review): this call passes `final=True` but not
                        # `entity_id`, unlike the per-item call below; transforms
                        # must accept both call shapes.
                        if endpoint.transform_fn:
                            yield sse_message(endpoint.transform_fn(entity, final=True))
                        return

                    data_source = endpoint.data_source(id)
                    # data_source(id) may hand back a zero-arg factory rather than
                    # an async iterator; call it once more in that case.
                    if callable(data_source) and not hasattr(data_source, '__anext__'):
                        data_source = data_source()

                    def transform(data):
                        if endpoint.transform_fn:
                            return endpoint.transform_fn(data, entity_id=id)
                        return data

                    config = endpoint.config if endpoint.config is not None else self.default_config
                    async for message in SSEStream(config).stream(data_source, transform):
                        yield message
                finally:
                    await self.connection_registry.remove_connection(connection.connection_id)

            return EventStream(stream())

        # Route names are derived from the endpoint's __name__; give each stream
        # its own so routes are distinguishable (and not collapsed by routers that
        # de-duplicate by name). The suffix keeps the name from ever matching an
        # HTTP verb, so FastHTML's method inference is unchanged.
        handler.__name__ = handler.__qualname__ = f"{stream_name}_stream"
        return handler

    def setup_routes(
        self,
        app,  # FastHTML app
        get_entity_fn: Callable,  # Function to get entity state
        is_active_fn: Callable  # Function to check if entity is active
    ):
        """Register one route per endpoint at `base_path + path_suffix` on `app`."""
        for stream_name, endpoint in self.endpoints.items():
            path = f"{self.base_path}{endpoint.path_suffix}"
            handler = self.create_handler(stream_name, get_entity_fn, is_active_fn)
            app.route(path)(handler)
            if self.debug:
                print(f"[MultiStreamManager] Registered route: {path}")

### Tests for MultiStreamManager

In [None]:
#| test
# Test MultiStreamManager initialization
manager = MultiStreamManager(
    base_path="/api/stream",
    debug=False
)

# These checks were previously commented out, so the cell passed vacuously.
assert manager.base_path == "/api/stream"
assert manager.debug is False
assert len(manager.endpoints) == 0
assert manager.connection_registry is not None
assert manager.default_config is not None

print("✓ MultiStreamManager initialization tests passed")

✓ MultiStreamManager initialization tests passed


In [None]:
#| test
# Test endpoint registration
manager = MultiStreamManager()

# Create test endpoints
async def progress_source(entity_id):
    for i in range(3):
        yield {"progress": i * 33}

def progress_transform(data, **kwargs):
    return f"Progress: {data.get('progress', 0)}%"

progress_endpoint = StreamEndpoint(
    name="progress",
    path_suffix="_progress",
    data_source=progress_source,
    transform_fn=progress_transform
)

status_endpoint = StreamEndpoint(
    name="status",
    path_suffix="_status",
    data_source=lambda id: (yield {"status": "running"})
)

# Register endpoints; register_endpoint returns the manager for chaining.
# (Registration is kept outside `assert` so it still runs under `python -O`.)
chained = manager.register_endpoint(progress_endpoint).register_endpoint(status_endpoint)
assert chained is manager

assert "progress" in manager.endpoints
assert "status" in manager.endpoints
assert manager.endpoints["progress"].path_suffix == "_progress"
assert manager.endpoints["status"].path_suffix == "_status"
assert manager.endpoints["progress"].transform_fn is progress_transform
assert manager.endpoints["status"].transform_fn is None

print("✓ MultiStreamManager endpoint registration tests passed")

✓ MultiStreamManager endpoint registration tests passed


In [None]:
#| test
# Test element creation
manager = MultiStreamManager(base_path="/sse")

# Register test endpoint
endpoint = StreamEndpoint(
    name="test",
    path_suffix="_test",
    data_source=lambda id: None
)
manager.register_endpoint(endpoint)

# Create element
element = manager.create_element(
    entity_id="entity123",
    stream_name="test",
    element_id="test-element"
)

assert element.attrs.get('id') == "test-element"
assert element.attrs.get('hx-ext') == "sse"
assert element.attrs.get('sse-connect') == "/sse_test?id=entity123"
assert element.attrs.get('sse-swap') == "message"

# Test with unknown stream: must raise ValueError, and must not silently succeed
try:
    manager.create_element("entity123", "unknown")
except ValueError as e:
    assert "Unknown stream" in str(e)
else:
    raise AssertionError("Should raise ValueError")

print("✓ MultiStreamManager element creation tests passed")

✓ MultiStreamManager element creation tests passed


## Helper Functions

In [None]:
#| export
def create_multi_stream_row(
    manager: MultiStreamManager,  # Multi-stream manager
    entity_id: str,  # Entity ID
    streams: List[Dict[str, Any]],  # List of stream configurations
    row_builder: Callable,  # Function to build the row
    active: bool = True  # Whether streams should be active
) -> FT:  # Table row or similar element
    """Create a row element with multiple SSE streams.

    Keys read from each dict in `streams`:
      - `name` (required): registered stream name passed to `manager.create_element`.
      - `cell_id`: element id, defaults to `f"{entity_id}-{name}"`.
      - `initial_content`: placeholder content, defaults to `''`.
      - `wrapper`: element factory, defaults to `Span`.
      - `active`: per-stream switch, defaults to `True`.

    When both `active` and the stream's own `active` flag are true, the cell is an
    SSE-connected element holding the initial content. Otherwise the cell is the
    initial content itself. Returns `row_builder(entity_id, {name: cell, ...})`.
    """
    stream_elements = {}

    for stream_config in streams:
        stream_name = stream_config['name']
        cell_id = stream_config.get('cell_id', f"{entity_id}-{stream_name}")
        content = stream_config.get('initial_content', '')

        if not (active and stream_config.get('active', True)):
            # Static content for inactive streams (no SSE connection)
            stream_elements[stream_name] = content
            continue

        element = manager.create_element(
            entity_id=entity_id,
            stream_name=stream_name,
            element_id=cell_id,
            wrapper=stream_config.get('wrapper', Span)
        )
        if content:
            # FT elements keep their children as a tuple (later `element(...)`
            # calls concatenate tuples), so normalise lists/tuples of nodes and
            # wrap a single node rather than assigning a list.
            element.children = tuple(content) if isinstance(content, (list, tuple)) else (content,)
        stream_elements[stream_name] = element

    return row_builder(entity_id, stream_elements)

In [None]:
#| test
# Test create_multi_stream_row
from fasthtml.common import Tr, Td

manager = MultiStreamManager()

# Register test endpoints
manager.register_endpoint(StreamEndpoint(
    name="progress",
    path_suffix="_progress",
    data_source=lambda id: None
))
manager.register_endpoint(StreamEndpoint(
    name="status",
    path_suffix="_status",
    data_source=lambda id: None
))

# Define row builder
def test_row_builder(entity_id, elements):
    return Tr(
        Td(entity_id),
        Td(elements.get('progress', 'N/A')),
        Td(elements.get('status', 'N/A')),
        id=f"row-{entity_id}"
    )

# Create row with active streams
row = create_multi_stream_row(
    manager=manager,
    entity_id="job123",
    streams=[
        {'name': 'progress', 'initial_content': 'Starting...'},
        {'name': 'status', 'initial_content': 'Pending'}
    ],
    row_builder=test_row_builder,
    active=True
)

assert row.attrs.get('id') == "row-job123"
assert len(row.children) == 3  # ID cell + 2 stream cells

# Test with inactive streams
static_row = create_multi_stream_row(
    manager=manager,
    entity_id="job456",
    streams=[
        {'name': 'progress', 'initial_content': '100%'},
        {'name': 'status', 'initial_content': 'Complete'}
    ],
    row_builder=test_row_builder,
    active=False
)

assert static_row.attrs.get('id') == "row-job456"
assert len(static_row.children) == 3  # ID cell + 2 static cells

print("✓ create_multi_stream_row tests passed")

✓ create_multi_stream_row tests passed


## Export

In [None]:
#| hide
# Regenerate the Python modules from this notebook's exported cells
import nbdev
nbdev.nbdev_export()