# In [1]:
import asyncio
from asyncio import AbstractEventLoop, Future, Task
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, TypedDict

import aiochan as ac

# Utils
def create_scanner(in_ch, batch_size=None):
    """Return ``in_ch.scan(...)`` set up to accumulate items into lists.

    The reducer handed to ``scan`` appends each incoming item to a *new*
    list, so the shared ``init=[]`` seed is never mutated.

    Args:
        in_ch: Object exposing ``scan(fn, init=...)`` (an aiochan channel
            in this file).
        batch_size: Optional cap on the accumulated list. When given, once
            the running list already holds ``batch_size`` items the next
            item starts a fresh one-element list instead of growing the old
            one. When ``None`` (the default) the list grows without bound,
            which is the original one-argument behaviour.

    Returns:
        Whatever ``in_ch.scan`` returns (presumably a channel of the
        intermediate lists -- confirm against the aiochan docs).
    """
    def _extend(batch, item):
        # A full batch has already been emitted; begin the next one.
        if batch_size is not None and len(batch) >= batch_size:
            return [item]
        return [*batch, item]

    return in_ch.scan(_extend, init=[])

class DispatchOpts(TypedDict, total=False):
    """Per-function dispatch options.

    Declared with ``total=False`` because the code that consumes these
    options reads every key through ``opts.get(key, default)``; callers may
    therefore supply any subset of the keys.
    """

    # Buffer size of the per-function input channel created on registration.
    buffer_size: int
    # Batch length at which the worker dispatches without waiting for a tick.
    batch_size: int
    # Interval, in seconds, handed to the worker's ticker channel.
    batch_timeout: float

# Fallback values used whenever a caller-supplied options dict omits a key.
default_dispatch_opts: DispatchOpts = {
    "buffer_size": 100,
    "batch_size": 1,
    "batch_timeout": 0.1,
}

# One queued call, in the order it is put on the dispatch channel.
DispatchPayload = Tuple[
    Callable,           # target function (also used as the pub/sub topic)
    Tuple[Any, ...],    # positional arguments
    Dict[str, Any],     # keyword arguments
    ac.Chan,            # channel on which the result is sent back
]

class Dispatcher:
    """Route function calls to per-function worker tasks over aiochan channels.

    Flow:
        1. ``register(func)`` creates an input channel for ``func``, subscribes
           it to the publisher under the topic ``func`` and starts a worker.
        2. ``dispatch(func, *args, **kwargs)`` puts a payload on the dispatch
           channel; the publisher forwards it to ``func``'s input channel.
        3. The worker collects payloads into a batch and flushes it when the
           batch reaches ``batch_size`` or when the ticker fires.
        4. Each payload's outcome is sent back on its private result channel
           and turned into the result/exception of the future that
           ``dispatch`` returned.

    NOTE(review): the aiochan behaviour relied on here -- ``select`` resolving
    to ``(value, channel)``, ``get`` yielding ``None`` on a closed channel,
    ``put`` returning ``False`` on a closed channel, ``Pub.sub(topic, chan)``
    -- is taken from the aiochan documentation; confirm against the pinned
    version.

    NOTE(review): every payload is invoked individually as
    ``func(*args, **kwargs)``, in arrival order. If targets are meant to
    receive a whole batch in a single call, change ``_process_batch``.
    """

    # Map to store the input channel and options for each registered function
    dispatch_map: Dict[Callable, Tuple[ac.Chan, DispatchOpts]]

    # Channel to receive dispatch requests
    dispatch_channel: ac.Chan   # ac.Chan[DispatchPayload]

    # Publisher to publish to the input channels
    publisher: ac.Pub   # ac.Pub[DispatchPayload]

    # Event loop
    loop: AbstractEventLoop

    def __init__(
        self,
        *,
        buffer_size: int = 100,
        loop: Optional[AbstractEventLoop] = None,
    ):
        """Create the dispatch channel and its publisher.

        Args:
            buffer_size: Buffer size of the dispatch channel and of the
                publisher.
            loop: Event loop to run on. Defaults to the running loop when
                constructed inside a coroutine, otherwise to a new loop.
        """
        # Prefer the loop we are already running on: tasks scheduled on a
        # freshly created loop would never execute unless the caller runs it.
        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
        self.loop = loop

        # The event loop only keeps weak references to tasks, so background
        # tasks are held here until they finish.
        self._tasks: set = set()

        # Set up the dispatch channel
        self.dispatch_map = {}
        self.dispatch_channel = ac.Chan(
            buffer_size=buffer_size,
            loop=self.loop,
            name="dispatcher.publisher"
        )

        # Set up the publisher; the topic of a payload is its target function
        self.publisher = self.dispatch_channel.pub(
            topic_fn=lambda payload: payload[0],
            buffer_size=buffer_size,
        )

    def _spawn(self, coro: Awaitable[Any]) -> Task:
        """Schedule ``coro`` on ``self.loop`` and retain it until done."""
        task = self.loop.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    def close(self) -> None:
        """Unregister every function, then close publisher and channel.

        Closing the input channels is what lets the worker tasks exit.
        Calls that have not yet reached a worker when this runs will not
        resolve.
        """
        for func in list(self.dispatch_map):
            self.unregister(func)

        # Close the publisher and dispatch channel
        self.publisher.close()
        self.dispatch_channel.close()

    def unregister(self, func: Callable) -> None:
        """Stop routing calls to ``func`` and close its input channel.

        Raises:
            AssertionError: If ``func`` is not registered.
        """
        # Raised explicitly so the check survives ``python -O`` while
        # keeping the exception type callers already see.
        if func not in self.dispatch_map:
            raise AssertionError("Dispatch target not registered")

        # Remove the function from the dispatch map
        in_ch, _ = self.dispatch_map.pop(func)

        # Unsubscribe the input channel, then close it; the worker exits
        # once it observes the closed channel.
        self.publisher.unsub(func, in_ch)
        in_ch.close()

    async def dispatch(self, func: Callable, *args, **kwargs) -> Future[Any]:
        """Queue ``func(*args, **kwargs)`` and return a future for its outcome.

        The coroutine itself completes once the request has been queued;
        await the returned future to obtain the call's result (or to have
        the call's exception raised).

        Raises:
            AssertionError: If ``func`` is not registered.
        """
        if func not in self.dispatch_map:
            raise AssertionError("Dispatch target not registered")

        # Create a channel to receive the result, and the future to expose it
        out_ch = ac.Chan(1, loop=self.loop)
        future = self.loop.create_future()

        # Dispatch the function call to the worker
        accepted = await self.dispatch_channel.put((func, args, kwargs, out_ch))
        if accepted is False:
            # The dispatch channel is closed: nobody will ever answer.
            out_ch.close()
            future.set_exception(RuntimeError("Dispatcher is closed"))
            return future

        # Resolve the future when the outcome arrives
        self._spawn(self._forward_result(out_ch, future))

        return future

    async def _forward_result(self, out_ch: ac.Chan, future: Future) -> None:
        """Move the ``(ok, value)`` outcome from ``out_ch`` into ``future``."""
        outcome = await out_ch.get()
        out_ch.close()

        if future.done():
            # The caller cancelled the future in the meantime.
            return
        if outcome is None:
            # ``None`` means the channel was closed without an outcome.
            future.set_exception(
                RuntimeError("Result channel closed before a result was produced")
            )
            return

        ok, value = outcome
        if ok:
            future.set_result(value)
        else:
            future.set_exception(value)

    async def _process_batch(self, dispatch_payloads: list) -> None:
        """Run every payload of a batch and send back its outcome.

        Outcomes travel as ``(True, result)`` or ``(False, exception)`` so
        that a ``None`` result stays distinguishable from a closed channel
        and errors reach the caller instead of killing the worker.
        """
        for func, args, kwargs, out_ch in dispatch_payloads:
            try:
                result = func(*args, **kwargs)
                # Support both plain callables and coroutine functions.
                if isinstance(result, Awaitable):
                    result = await result
                outcome = (True, result)
            except Exception as exc:  # forwarded to the caller's future
                outcome = (False, exc)
            await out_ch.put(outcome)

    async def _worker(self, func) -> None:
        """Batch the payloads arriving for ``func`` and process them.

        A batch is flushed when it holds ``batch_size`` payloads or when the
        ticker (period ``batch_timeout``) fires while payloads are pending.
        The worker exits once the input channel is closed, after flushing
        whatever is still pending.
        """
        entry = self.dispatch_map.get(func)
        if entry is None:
            # Unregistered before this task got a chance to start.
            return

        # Unpack the input channel and options
        in_ch, opts = entry
        batch_size = opts.get('batch_size', default_dispatch_opts['batch_size'])
        batch_timeout = opts.get('batch_timeout', default_dispatch_opts['batch_timeout'])

        ticker = ac.tick_tock(batch_timeout, loop=self.loop)
        pending: list = []

        try:
            while True:
                value, source = await ac.select(in_ch, ticker)

                if source is in_ch:
                    if value is None:
                        # Input channel closed: stop collecting.
                        break
                    pending.append(value)
                    if len(pending) < batch_size:
                        continue
                elif not pending:
                    # Tick with nothing queued.
                    continue

                batch, pending = pending, []
                await self._process_batch(batch)
        finally:
            ticker.close()

        # Flush what was collected before the channel closed
        if pending:
            await self._process_batch(pending)

    def register(self, func, opts: DispatchOpts = default_dispatch_opts) -> 'Dispatcher':
        """Register ``func`` as a dispatch target and start its worker.

        Args:
            func: Hashable callable (sync or async); used as pub/sub topic.
            opts: Dispatch options; missing keys fall back to
                ``default_dispatch_opts``.

        Returns:
            This dispatcher, so registrations can be chained.

        Raises:
            AssertionError: If ``func`` is already registered.
        """
        if func in self.dispatch_map:
            raise AssertionError("Dispatch target already registered")

        buffer_size = opts.get('buffer_size', default_dispatch_opts['buffer_size'])

        # Create the input channel for this function
        in_ch = ac.Chan(
            buffer_size=buffer_size,
            loop=self.loop,
        )

        # Add the function to the dispatch map
        self.dispatch_map[func] = (in_ch, opts)

        # Route payloads for this function to its channel and start consuming
        self.publisher.sub(func, in_ch)
        self._spawn(self._worker(func))

        return self

        