# Streaming

> SSE streaming utilities and helpers

In [None]:
#| default_exp core.streaming

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import asyncio
import inspect
import json
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator, Callable, Optional, Any, Dict, List, Union

from fasthtml.common import EventStream, sse_message, Div, FT

## Core Classes

In [None]:
#| export
@dataclass
class StreamConfig:
    """Configuration for SSE streaming.

    Read by `SSEStream`. Fields can be passed positionally to the generated
    ``__init__``, so their order is part of the public interface.
    """
    # Seconds SSEStream waits for the next item from the data source before
    # emitting a ": heartbeat <timestamp>" SSE comment line.
    heartbeat_interval: float = 30.0
    # NOTE(review): not read by SSEStream in this module; the intended meaning
    # (overall stream deadline vs. idle timeout) is undocumented -- confirm.
    timeout: Optional[float] = None
    # When True, SSEStream opens the stream with `initial_message` as an SSE comment line.
    send_initial_message: bool = True
    initial_message: str = "Connected"
    # When True, SSEStream emits `close_message` as an SSE comment line when the stream finishes.
    send_close_message: bool = True
    close_message: str = "Connection closed"
    # When True, SSEStream emits the exception text as an SSE comment before re-raising it.
    debug: bool = False

### Tests for StreamConfig

In [None]:
#| test
# Test StreamConfig creation and defaults
# Expected value of every field on a default-constructed config
default_expectations = {
    "heartbeat_interval": 30.0,
    "timeout": None,
    "send_initial_message": True,
    "initial_message": "Connected",
    "send_close_message": True,
    "close_message": "Connection closed",
    "debug": False,
}
config = StreamConfig()
for field_name, expected in default_expectations.items():
    actual = getattr(config, field_name)
    # None is checked by identity; everything else by equality
    assert (actual is None) if expected is None else (actual == expected), field_name

# Test custom configuration: every field overridden by keyword
custom_values = dict(
    heartbeat_interval=15.0,
    timeout=60.0,
    send_initial_message=False,
    initial_message="Started",
    send_close_message=False,
    close_message="Ended",
    debug=True,
)
custom_config = StreamConfig(**custom_values)
for field_name, expected in custom_values.items():
    assert getattr(custom_config, field_name) == expected, field_name

print("✓ StreamConfig tests passed")

✓ StreamConfig tests passed


In [None]:
#| export
class SSEStream:
    """Generic SSE stream handler.

    Pulls items from an async data source and formats each one as a
    Server-Sent Events message. Whenever the source is idle for longer than
    ``config.heartbeat_interval`` seconds, it interleaves ``: heartbeat``
    comment lines.
    """
    
    def __init__(
        self,
        config: Optional["StreamConfig"] = None  # Stream configuration
    ):
        """Initialize the SSE stream."""
        self.config = config or StreamConfig()
        self._active = False  # True while `stream()` runs; cleared by `stop()`
        self._message_count = 0  # Data messages emitted (comments/heartbeats excluded)
    
    async def stream(self,
                    data_source: Union[AsyncGenerator, Callable], # Async iterable, or a callable returning one
                    transform_fn: Optional[Callable] = None # Optional sync or async function applied to each item before sending
                    ) -> AsyncGenerator[str, None]: # SSE formatted strings
        """Stream data from a source through SSE.

        Yields, in order:
        - the optional initial comment;
        - one ``data:`` message per item that is not None after transformation,
          with heartbeat comments while the source is idle;
        - the optional close comment.

        The close comment is sent when the source is exhausted, when `stop()` is
        called, or when an error occurs. After an error, the exception is
        re-raised following the close comment.

        One read from the source stays in flight across heartbeats, so an idle
        source is never cancelled by the heartbeat timer. If the consumer closes
        this generator early (``aclose()``) or its task is cancelled, the
        in-flight read is cancelled and nothing more is yielded.
        """
        self._active = True
        exhausted = object()  # sentinel: the source has no more items
        pending = None  # in-flight read from the source, reused across heartbeats

        try:
            try:
                if self.config.send_initial_message:
                    yield f": {self.config.initial_message}\n\n"

                iterator = self._as_async_iterator(data_source)

                async def pull():
                    # Run __anext__ inside a real coroutine so it can be a Task;
                    # report exhaustion as a sentinel instead of an exception.
                    try:
                        return await iterator.__anext__()
                    except StopAsyncIteration:
                        return exhausted

                while self._active:
                    if pending is None:
                        pending = asyncio.create_task(pull())
                    # Unlike wait_for, asyncio.wait does not cancel on timeout, so the
                    # source keeps its state and the same read resumes after a heartbeat.
                    done, _ = await asyncio.wait({pending}, timeout=self.config.heartbeat_interval)
                    if not done:
                        yield f": heartbeat {datetime.now().isoformat()}\n\n"
                        continue

                    finished, pending = pending, None
                    data = finished.result()  # re-raises any error from the source
                    if data is exhausted:
                        break

                    if transform_fn is not None:
                        data = transform_fn(data)
                        # Covers async functions, partials of them, async __call__, ...
                        if inspect.isawaitable(data):
                            data = await data

                    if data is not None:
                        self._message_count += 1
                        yield self._format_message(data)

            except Exception as e:
                if self.config.debug:
                    yield f": error {str(e)}\n\n"
                self._active = False
                if self.config.send_close_message:
                    yield f": {self.config.close_message}\n\n"
                raise

            self._active = False
            if self.config.send_close_message:
                yield f": {self.config.close_message}\n\n"

        finally:
            # No yields here. Yielding while handling GeneratorExit raises
            # RuntimeError, and yielding during cancellation would swallow it.
            self._active = False
            if pending is not None and not pending.done():
                pending.cancel()

    @staticmethod
    def _as_async_iterator(data_source):
        """Return an async iterator for `data_source`.

        Callables without ``__anext__`` (e.g. async generator functions) are
        called first. Async iterables are then resolved through ``__aiter__``.

        Raises:
            TypeError: if the (called) source is not async-iterable.
        """
        source = data_source
        if callable(source) and not hasattr(source, '__anext__'):
            source = source()
        if hasattr(source, '__aiter__'):
            return source.__aiter__()
        if hasattr(source, '__anext__'):
            return source
        raise TypeError(
            "data_source must be an async iterable or a callable returning one, "
            f"got {type(data_source).__name__}"
        )
    
    def _format_message(
        self,
        data: Any  # Data to format
    ) -> str:  # SSE formatted string
        """Format data as an SSE message.

        Strings are sent as-is. FastHTML elements (anything with ``__html__``,
        other than a dict) go through `sse_message`. Everything else is sent as
        JSON, falling back to ``str(data)`` when it is not JSON-serializable.
        """
        if isinstance(data, str):
            payload = data
        elif hasattr(data, '__html__') and not isinstance(data, dict):
            # FastHTML element
            return sse_message(data)
        else:
            try:
                payload = json.dumps(data)
            except (TypeError, ValueError):
                payload = str(data)
        return self._data_lines(payload)

    @staticmethod
    def _data_lines(payload: str) -> str:
        """Encode `payload` as one SSE event with one ``data:`` field per line.

        SSE treats CR, LF and CRLF as line terminators. Without one prefix per
        line, an embedded newline would end the field early, and a blank line
        would end the whole event early.
        """
        lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        return "".join(f"data: {line}\n" for line in lines) + "\n"
    
    def stop(self):
        """Ask a running `stream()` to finish.

        The flag is checked between reads, so the stream ends after the current
        item is delivered or the next heartbeat fires.
        """
        self._active = False

### Tests for SSEStream

In [None]:
#| test
# Test SSEStream initialization
# Default construction: falls back to a fresh StreamConfig and starts idle
stream = SSEStream()
assert (stream.config.heartbeat_interval, stream._active, stream._message_count) == (30.0, False, 0)

# An explicitly supplied config is kept as-is
custom_config = StreamConfig(heartbeat_interval=10.0, debug=True)
stream_custom = SSEStream(config=custom_config)
assert (stream_custom.config.heartbeat_interval, stream_custom.config.debug) == (10.0, True)

print("✓ SSEStream initialization tests passed")

✓ SSEStream initialization tests passed


In [None]:
#| test
# Test SSEStream streaming functionality
import asyncio

async def test_sse_stream():
    # Create a simple async generator
    async def data_generator():
        for i in range(3):
            yield f"message_{i}"
            await asyncio.sleep(0.01)
    
    # Test basic streaming
    config = StreamConfig(
        heartbeat_interval=5.0,
        send_initial_message=True,
        send_close_message=True
    )
    stream = SSEStream(config)
    
    messages = []
    async for msg in stream.stream(data_generator()):
        messages.append(msg)
    
    # Check messages
    assert len(messages) >= 5  # Initial + 3 data + close
    assert "Connected" in messages[0]
    assert "data: message_0" in messages[1]
    assert "data: message_1" in messages[2]
    assert "data: message_2" in messages[3]
    assert "Connection closed" in messages[-1]
    assert stream._message_count == 3
    
    print("✓ SSEStream streaming tests passed")

await test_sse_stream()

✓ SSEStream streaming tests passed


In [None]:
#| test
# Test SSEStream with transform function
async def test_sse_stream_transform():
    # Create data generator
    async def data_generator():
        for i in range(2):
            yield i
    
    # Define transform function
    def transform(x):
        return {"value": x * 2, "original": x}
    
    config = StreamConfig(
        send_initial_message=False,
        send_close_message=False
    )
    stream = SSEStream(config)
    
    messages = []
    async for msg in stream.stream(data_generator(), transform):
        messages.append(msg)
    
    # Check transformed messages
    assert len(messages) == 2
    # Messages should be formatted as SSE with the transformed dict data
    assert "value" in str(messages[0])
    assert "original" in str(messages[0])
    
    print("✓ SSEStream transform tests passed")

await test_sse_stream_transform()

✓ SSEStream transform tests passed


In [None]:
#| test
# Test SSEStream stop functionality
async def test_sse_stream_stop():
    stream = SSEStream()
    assert stream._active == False
    
    # The stop method should set _active to False
    stream._active = True
    stream.stop()
    assert stream._active == False
    
    print("✓ SSEStream stop tests passed")

await test_sse_stream_stop()

✓ SSEStream stop tests passed


In [None]:
#| export
class OOBStreamBuilder:
    """Build SSE messages with OOB (Out-of-Band) swaps"""
    
    def __init__(self):
        """Initialize the OOB stream builder."""
        self.elements: List[Any] = []  # Added elements, in order, with OOB attributes applied
    
    def add_element(self,
                   element: Any,  # The element to add
                   target_id: Optional[str] = None,  # Target element ID for OOB swap
                   swap_mode: str = "innerHTML",  # Swap mode (innerHTML, outerHTML, beforeend, afterbegin, etc.)
                   wrap: bool = True  # If True and target_id is provided, wrap content in a Div with OOB attributes. If False, add OOB attributes directly to the element
                   ) -> 'OOBStreamBuilder':  # Self for chaining
        """Add an element with OOB swap configuration.

        Without `target_id`, the element is stored unchanged. With `target_id`:

        - ``wrap=True`` and any mode other than ``outerHTML``: the element is
          wrapped in ``Div(element, id=target_id, hx_swap_oob=swap_mode)``.
        - ``outerHTML`` or ``wrap=False``: ``id`` and ``hx-swap-oob`` are set
          on the element itself (mutating it) if it has an ``attrs`` mapping or
          is a dict. If the element already has a different id, that id is
          kept and the target is addressed with the htmx selector form
          ``"<swap_mode>:#<target_id>"``. Any other element is stored unchanged.
        """
        if not target_id:
            # No target_id, just add the element as-is
            self.elements.append(element)
            return self

        if wrap and swap_mode != "outerHTML":
            # innerHTML/beforeend/afterbegin/...: the wrapper carries the target id,
            # and htmx swaps the wrapper's children into the matching element.
            self.elements.append(Div(element, id=target_id, hx_swap_oob=swap_mode))
            return self

        if hasattr(element, 'attrs'):
            attrs = element.attrs
        elif isinstance(element, dict):
            attrs = element
        else:
            # NOTE(review): nothing to attach OOB attributes to, so target_id has no effect here.
            attrs = None

        if attrs is not None:
            current_id = attrs.get('id')
            if not current_id or current_id == target_id:
                attrs['id'] = target_id
                # Use the requested mode verbatim: htmx reads "true" as outerHTML,
                # so mapping innerHTML to "true" would change the swap.
                attrs['hx-swap-oob'] = swap_mode
            else:
                # Keep the element's own id but still swap into target_id
                attrs['hx-swap-oob'] = f"{swap_mode}:#{target_id}"
        self.elements.append(element)
        return self
    
    def add_elements(
        self,
        elements: List[tuple]  # List of tuples: (element, target_id, swap_mode, wrap) or (element, target_id, swap_mode) or (element, target_id) or (element,)
    ) -> 'OOBStreamBuilder':  # Self for chaining
        """Add multiple elements with OOB configurations.

        Each tuple is passed to `add_element` positionally, so missing trailing
        entries take `add_element`'s defaults (no target, ``"innerHTML"``,
        ``wrap=True``). Non-tuple items are added as plain elements with no target.

        Raises:
            ValueError: if a tuple is empty or has more than four entries.
        """
        for item in elements:
            if not isinstance(item, tuple):
                self.add_element(item)
            elif 1 <= len(item) <= 4:
                self.add_element(*item)
            else:
                raise ValueError(
                    "expected a tuple of 1 to 4 entries "
                    f"(element, target_id, swap_mode, wrap), got {len(item)}"
                )
        return self
    
    def build(
        self
    ) -> "Union[FT, str]":  # "" when empty, an SSE string for one element, otherwise a Div
        """Build the output for all added elements.

        Returns ``""`` when nothing was added, ``sse_message(element)`` (an
        SSE-formatted string) for exactly one element, and ``Div(*elements)``
        otherwise.

        NOTE(review): the single-element result is a ready-to-send SSE string,
        while the multi-element result is an unformatted Div. Confirm which shape
        callers expect before relying on either.
        """
        if not self.elements:
            return ""
        
        if len(self.elements) == 1:
            return sse_message(self.elements[0])
        
        # Wrap multiple elements in a container
        return Div(*self.elements)
    
    def clear(
        self
    ) -> 'OOBStreamBuilder':  # Self for chaining
        """Remove all added elements."""
        self.elements = []
        return self

### Tests for OOBStreamBuilder

In [None]:
#| test
# Test OOBStreamBuilder initialization and basic operations
builder = OOBStreamBuilder()
assert builder.elements == []  # a fresh builder holds nothing

# An element added without a target is stored unchanged
builder.add_element("Hello World")
assert builder.elements == ["Hello World"]

# clear() empties the builder again
builder.clear()
assert not builder.elements

print("✓ OOBStreamBuilder initialization tests passed")

✓ OOBStreamBuilder initialization tests passed


In [None]:
#| test
# Test OOBStreamBuilder with target_id and swap modes
from fasthtml.common import Div, P, Span

def oob_attrs(el):
    # The (id, hx-swap-oob) pair the builder set on an element
    return el.attrs.get('id'), el.attrs.get('hx-swap-oob')

builder = OOBStreamBuilder()

# Default swap mode (innerHTML): content is wrapped in a Div carrying the target id
content = P("Paragraph content")
builder.add_element(content, target_id="target1")
assert len(builder.elements) == 1
wrapper = builder.elements[0]
assert isinstance(wrapper, FT)
assert oob_attrs(wrapper) == ("target1", "innerHTML")

# A non-default swap mode is passed through to the wrapper
builder.clear()
builder.add_element(Span("Span content"), target_id="target2", swap_mode="beforeend")
assert len(builder.elements) == 1
wrapper = builder.elements[0]
assert oob_attrs(wrapper) == ("target2", "beforeend")

# outerHTML with wrap=False: attributes go on the element itself, with no wrapper
builder.clear()
element = Div("Direct element", id="existing_id")
builder.add_element(element, target_id="existing_id", swap_mode="outerHTML", wrap=False)
assert len(builder.elements) == 1
assert oob_attrs(builder.elements[0]) == ("existing_id", "outerHTML")

print("✓ OOBStreamBuilder target_id tests passed")

✓ OOBStreamBuilder target_id tests passed


In [None]:
#| test
# Test OOBStreamBuilder add_elements
from fasthtml.common import Div, P, Span

builder = OOBStreamBuilder()

# One spec per supported tuple length: 4, 3, 2 and 1 entries
specs = [
    (P("Element 1"), "id1", "innerHTML", True),
    (P("Element 2"), "id2", "beforeend"),
    (P("Element 3"), "id3"),
    (P("Element 4"),),
]
builder.add_elements(specs)
assert len(builder.elements) == len(specs)

first, second, third, fourth = builder.elements
# 4-tuple and 3-tuple: wrapped with the given target and mode
assert first.attrs.get('id') == "id1" and first.attrs.get('hx-swap-oob') == "innerHTML"
assert second.attrs.get('id') == "id2" and second.attrs.get('hx-swap-oob') == "beforeend"
# 2-tuple: wrapped with the target and the default mode
assert third.attrs.get('id') == "id3"
# 1-tuple: no target, so the element itself is stored
assert isinstance(fourth, FT)

print("✓ OOBStreamBuilder add_elements tests passed")

✓ OOBStreamBuilder add_elements tests passed


In [None]:
#| test
# Test OOBStreamBuilder build method
from fasthtml.common import Div, P

builder = OOBStreamBuilder()

# No elements -> empty string
result = builder.build()
assert result == ""

# Single element -> SSE-formatted string (via sse_message) containing the element
builder.add_element(P("Single"))
result = builder.build()
assert isinstance(result, str)
assert "data: " in result and "Single" in result

# Multiple elements -> one Div holding all of them
builder.clear()
builder.add_element(P("First"))
builder.add_element(P("Second"))
result = builder.build()
assert isinstance(result, FT)
assert len(result.children) == 2

# Method chaining returns the same builder; clear() resets the list mid-chain
builder2 = OOBStreamBuilder()
result = (builder2
    .add_element(P("One"))
    .add_element(P("Two"))
    .clear()
    .add_element(P("Three")))
assert result is builder2
assert len(builder2.elements) == 1

print("✓ OOBStreamBuilder build tests passed")

✓ OOBStreamBuilder build tests passed


## Export

In [None]:
#| hide
# Regenerate the library modules from `#| export` cells; this notebook's cells
# target `core.streaming` (set by `#| default_exp` at the top).
import nbdev; nbdev.nbdev_export()