# decorators

> Route decorators and utilities for FastHTML SSE endpoints

In [None]:
#| default_exp core.decorators

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import asyncio
import functools
from typing import Optional, Callable, Any, Dict, Union, AsyncIterator, AsyncGenerator
from inspect import signature, iscoroutinefunction

# Import EventStream from FastHTML (it's a function that creates StreamingResponse)
from starlette.responses import StreamingResponse

try:
    # Prefer FastHTML's own helper when the package is installed
    from fasthtml.common import EventStream
except ImportError:
    # Minimal stand-in used when FastHTML is not available
    def EventStream(
        s: AsyncGenerator  # Async generator producing SSE-formatted strings
    ):
        "Wrap `s` in a StreamingResponse served as text/event-stream"
        return StreamingResponse(s, media_type="text/event-stream")

# Import broadcast and streaming components
from cjm_fasthtml_sse.core.broadcast import (
    SSEBroadcastManager,
    BroadcastMessage
)
from cjm_fasthtml_sse.core.streaming import (
    sse_broadcast_stream,
    SSEStreamConfig,
    format_sse_message,
    sse_generator
)

## SSE Endpoint Decorator

In [None]:
#| export
def sse_endpoint(
    broadcast_manager: Optional[SSEBroadcastManager] = None,  # Optional SSEBroadcastManager for automatic broadcasting
    config: Optional[SSEStreamConfig] = None,  # Optional SSEStreamConfig for stream configuration
    message_filter: Optional[Callable[[BroadcastMessage], bool]] = None,  # Default message filter (broadcast mode only)
    client_metadata_fn: Optional[Callable[..., Dict[str, Any]]] = None  # Called with the endpoint's arguments to build client metadata
):
    """Decorator for creating SSE endpoints with optional broadcast integration.

    The decorator works in two modes:

    1. With `broadcast_manager`: the decorated function is called once per
       endpoint call for setup. If it returns a callable, that callable replaces
       `message_filter`. The response streams from `sse_broadcast_stream`.
    2. Without `broadcast_manager`: the decorated function produces the data.
       It may be any of the following:
       - an async generator function;
       - a sync generator function;
       - a coroutine function returning an (async) iterator;
       - a function returning a single value.
       String items are passed through as already-formatted SSE text. Other
       items are formatted with `format_sse_message`. A single non-iterator
       return value is sent as one message.
    """
    from collections.abc import Iterator  # only needed to recognise sync generators/iterators

    async def _format_stream(source):
        """Yield SSE text for everything `source` produces (mode 2)."""
        if hasattr(source, '__aiter__'):
            async for item in source:
                yield item if isinstance(item, str) else format_sse_message(item)
        elif isinstance(source, Iterator):
            # Sync generators/iterators are consumed item by item. Formatting
            # the iterator object itself would not send the items it produces.
            for item in source:
                yield item if isinstance(item, str) else format_sse_message(item)
        else:
            # A single non-iterator value becomes one message
            yield format_sse_message(source)

    def decorator(
        func: Callable  # The endpoint function to decorate
    ):
        """Apply SSE endpoint decoration to a function"""
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            """Wrapped SSE endpoint handler"""
            # Extract client metadata if a function was provided
            client_metadata = None
            if client_metadata_fn is not None:
                client_metadata = client_metadata_fn(*args, **kwargs)

            # Call the user function in both modes. Only plain coroutine
            # functions are awaited; async generator functions hand back
            # their generator directly.
            if iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)

            # Mode 1: stream from the broadcast manager. Compare against None
            # explicitly, because a manager object could define __len__/__bool__
            # and be falsy while it is idle.
            if broadcast_manager is not None:
                # A callable returned by the function overrides the default filter
                filter_fn = result if callable(result) else message_filter
                generator = sse_broadcast_stream(
                    broadcast_manager,
                    client_metadata=client_metadata,
                    config=config,
                    message_filter=filter_fn
                )
                return EventStream(generator)

            # Mode 2: wrap whatever the function produced
            return EventStream(_format_stream(result))

        return wrapper
    return decorator

## Broadcast Action Decorator

In [None]:
#| export
def broadcast_action(
    manager: SSEBroadcastManager,  # SSEBroadcastManager to broadcast through
    event_type: Optional[str] = None,  # Event type to broadcast (defaults to function name)
    extract_data: Optional[Callable[..., Dict[str, Any]]] = None,  # Called with the function's arguments to build the base payload
    broadcast_before: bool = False,  # Broadcast before executing function
    broadcast_after: bool = True,  # Broadcast after executing function
    include_result: bool = True  # Include function result in broadcast data
):
    """Decorator that broadcasts events when actions occur.

    Below, ``<event>`` means `event_type`, or the decorated function's name when
    `event_type` is None. The events emitted are:

    - ``<event>_started`` with the extracted data, if `broadcast_before` is set.
    - ``<event>`` after the function returns, if `broadcast_after` is set. Its
      payload is the extracted data plus the result: a dict result is merged
      in, and any other non-None result goes under ``"result"``.
    - ``<event>_error`` with ``{"error": str(exc), **data}`` if the function
      raises. The exception is then re-raised.

    Both async and sync functions are supported. For sync functions, behavior
    depends on whether an event loop is running in the current thread:

    - No running loop: each broadcast runs to completion on the thread's event
      loop before execution continues.
    - Running loop (e.g. called from async code): blocking is impossible, so
      each broadcast is scheduled as a task on that loop instead.
    """
    def decorator(
        func: Callable  # The function to decorate with broadcast capability
    ):
        """Apply broadcast action decoration to a function"""
        # Resolve the event name per decorated function instead of rebinding
        # the enclosing `event_type`. Rebinding would make the first function's
        # name stick when one decorator instance is applied to several functions.
        name = event_type if event_type is not None else func.__name__
        # Strong references to broadcast tasks scheduled from sync code, so
        # they are not garbage-collected before they run
        pending_tasks = set()

        def _base_data(args, kwargs):
            """Payload shared by all events of one call ({} without extract_data)."""
            data = extract_data(*args, **kwargs) if extract_data else None
            return data if data is not None else {}

        def _success_data(data, result):
            """Base payload plus the function result (merged when it is a dict)."""
            payload = data.copy()  # never mutate the extracted data in place
            if include_result and result is not None:
                if isinstance(result, dict):
                    payload.update(result)
                else:
                    payload['result'] = result
            return payload

        def _sync_runner():
            """Return a callable that drives broadcast coroutines from sync code."""
            try:
                running_loop = asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread: block until each
                # broadcast completes on the thread's event loop.
                return asyncio.get_event_loop().run_until_complete

            # Inside a running loop run_until_complete would raise, so
            # schedule the broadcast on that loop instead.
            def schedule(coro):
                task = running_loop.create_task(coro)
                pending_tasks.add(task)
                task.add_done_callback(pending_tasks.discard)
                return task
            return schedule

        if iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                """Async wrapper that handles broadcasting for async functions"""
                data = _base_data(args, kwargs)
                if broadcast_before:
                    await manager.broadcast(f"{name}_started", data)
                try:
                    result = await func(*args, **kwargs)
                except Exception as e:
                    await manager.broadcast(f"{name}_error", {"error": str(e), **data})
                    raise
                # Kept outside the try: a failing success broadcast must not be
                # reported as a failure of the action itself.
                if broadcast_after:
                    await manager.broadcast(name, _success_data(data, result))
                return result

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            """Sync wrapper that handles broadcasting for sync functions"""
            # Resolved before the function runs, so an unusable event loop
            # fails before the action's side effects happen.
            run = _sync_runner()
            data = _base_data(args, kwargs)
            if broadcast_before:
                run(manager.broadcast(f"{name}_started", data))
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                run(manager.broadcast(f"{name}_error", {"error": str(e), **data}))
                raise
            if broadcast_after:
                run(manager.broadcast(name, _success_data(data, result)))
            return result

        return sync_wrapper

    return decorator

## Simple SSE Generator Decorator

In [None]:
#| export
def sse_generator_endpoint(
    interval: float = 1.0,  # Seconds between polls of the data function
    event_type: Optional[str] = None,  # Optional SSE event name for each message
    heartbeat: Optional[float] = 30.0  # Heartbeat interval in seconds
):
    """Decorator that turns a plain data function into a polling SSE endpoint.

    Each call of the decorated endpoint builds an `sse_generator` stream whose
    data source re-invokes the data function with that call's arguments. The
    stream is returned wrapped in an `EventStream` response.
    """
    # Stream settings are fixed when the decorator is created
    stream_options = dict(interval=interval, event_type=event_type, heartbeat=heartbeat)

    def decorator(
        func: Callable  # Data function whose return values feed the stream
    ):
        """Wrap `func` as an SSE streaming endpoint"""
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            """Endpoint handler streaming the values produced by `func`"""
            def data_source():
                """Re-invoke the wrapped function with this call's arguments"""
                return func(*args, **kwargs)

            return EventStream(sse_generator(data_source, **stream_options))

        return wrapper
    return decorator

## Example Usage

In [None]:
# Example: Using SSE endpoint decorator with broadcast
async def example_broadcast_endpoint():
    manager = SSEBroadcastManager(debug=True)
    
    # Simulate a decorated endpoint
    @sse_endpoint(broadcast_manager=manager)
    async def updates_endpoint():
        # Could return a filter function here
        return lambda msg: msg.type != "internal"
    
    # Get the EventStream
    stream = await updates_endpoint()
    print(f"Created EventStream: {stream}")
    
    # Broadcast a test message
    await manager.broadcast("test", {"value": 42})
    
    return stream

stream = await example_broadcast_endpoint()
print(f"\nEventStream type: {type(stream).__name__}")

Created EventStream: <starlette.responses.StreamingResponse object>
[SSEBroadcastManager] No active connections to broadcast to

EventStream type: StreamingResponse


In [None]:
# Example: Custom generator endpoint
async def example_custom_generator():
    
    @sse_endpoint()
    async def custom_stream():
        for i in range(3):
            yield {"count": i}
            await asyncio.sleep(0.1)
    
    # Get the EventStream
    stream = await custom_stream()
    
    # FastHTML's EventStream handles the generator internally
    # We can't directly iterate it here, but in a real app it would
    # be returned as a response
    print(f"Created custom EventStream: {type(stream).__name__}")
    
    # To test the generator, we'll create it directly
    async def test_gen():
        for i in range(3):
            yield {"count": i}
    
    messages = []
    async for data in test_gen():
        msg = format_sse_message(data)
        messages.append(msg)
    
    print("\nGenerated messages:")
    for msg in messages:
        print(msg.strip())

await example_custom_generator()

Created custom EventStream: StreamingResponse

Generated messages:
data: {"count": 0}
data: {"count": 1}
data: {"count": 2}


In [None]:
# Example: Broadcast action decorator
async def example_broadcast_action():
    manager = SSEBroadcastManager(debug=True)
    
    @broadcast_action(
        manager,
        event_type="item_created",
        extract_data=lambda item: {"name": item["name"]}
    )
    async def create_item(item):
        # Simulate item creation
        item["id"] = "123"
        return item
    
    # Call the decorated function
    result = await create_item({"name": "Test Item"})
    print(f"Created item: {result}")
    
    # Check broadcast history
    history = manager.get_history()
    if history:
        print(f"\nBroadcast message: {history[-1].to_dict()}")

await example_broadcast_action()

[SSEBroadcastManager] No active connections to broadcast to
Created item: {'name': 'Test Item', 'id': '123'}

Broadcast message: {'type': 'item_created', 'data': {'name': 'Test Item', 'id': '123'}, 'timestamp': '2025-08-27T13:00:14.687592'}


In [None]:
# Example: Simple generator endpoint
async def example_generator_endpoint():
    counter = 0
    
    @sse_generator_endpoint(
        interval=0.1,
        event_type="counter",
        heartbeat=None
    )
    def get_counter():
        nonlocal counter
        counter += 1
        return {"value": counter} if counter <= 3 else None
    
    # Get the EventStream
    stream = await get_counter()
    print(f"Created generator EventStream: {type(stream).__name__}")
    
    # Test the underlying generator logic separately
    counter = 0
    messages = []
    for _ in range(3):
        counter += 1
        msg = format_sse_message({"value": counter}, event="counter")
        messages.append(msg)
    
    print("\nCounter messages:")
    for msg in messages:
        print(msg.strip())

await example_generator_endpoint()

Created generator EventStream: StreamingResponse

Counter messages:
event: counter
data: {"value": 1}
event: counter
data: {"value": 2}
event: counter
data: {"value": 3}


## Testing

In [None]:
# Test: EventStream with decorators
async def test_event_stream():
    # Test that our decorators work with FastHTML's EventStream
    async def simple_generator():
        yield "data: test\n\n"
        yield "data: done\n\n"
    
    stream = EventStream(simple_generator())
    
    # Check that it's a StreamingResponse (what EventStream returns)
    assert isinstance(stream, StreamingResponse)
    
    # Check the media type is correct
    assert stream.media_type == "text/event-stream"
    
    print("✓ EventStream integration tests passed")

await test_event_stream()

✓ EventStream integration tests passed


In [None]:
# Test: Decorator functionality
async def test_decorators():
    manager = SSEBroadcastManager()
    
    # Test sse_endpoint with custom generator
    @sse_endpoint()
    async def test_endpoint():
        yield {"test": 1}
    
    stream = await test_endpoint()
    assert isinstance(stream, StreamingResponse)
    assert stream.media_type == "text/event-stream"
    
    # Test broadcast_action
    events = []
    
    @broadcast_action(
        manager,
        event_type="test_action"
    )
    async def test_action(value):
        return {"result": value * 2}
    
    result = await test_action(21)
    assert result["result"] == 42
    
    # Check broadcast occurred
    history = manager.get_history()
    assert len(history) > 0
    assert history[-1].type == "test_action"
    
    print("✓ Decorator tests passed")

await test_decorators()

✓ Decorator tests passed


In [None]:
#| hide
import nbdev
nbdev.nbdev_export()