# In [1]:
import asyncio
import aiochan as ac

# In [5]:
from typing import Any, Callable, Literal, Optional, TypedDict

class DispatchOpts(TypedDict, total=False):
    """Per-function options accepted by ``Dispatcher.register``.

    All keys are optional (``total=False``): ``Dispatcher.register`` reads
    ``buffer_size`` with a fallback to ``default_dispatch_opts``, so callers
    are expected to pass partial dicts such as ``{'buffer_size': 10}``.
    """

    buffer_size: int      # buffer size of the function's input channel
    batch_size: int       # not read by the code shown alongside this type
    batch_timeout: float  # not read here either; units presumably seconds — TODO confirm

# Fallback options for Dispatcher.register (its default argument, and the
# source of the buffer_size fallback).
default_dispatch_opts: DispatchOpts = {
    'buffer_size': 100,
    'batch_size': 1,
    'batch_timeout': 0.1,
}

# Shape of a routed call: (target function, positional args, keyword args,
# reply channel). The publisher's topic function keys on element 0.
DispatchPayload = tuple[
    Callable,
    tuple[Any, ...],
    dict[str, Any],
    ac.Chan,
]

# Accepted values for Dispatcher(parallel_executor_mode=...).
ParallelExecutorMode = Literal["thread", "process"]

class Dispatcher:
    """Hand calls for registered functions to per-function aiochan channels.

    ``register`` creates one input channel per function. ``dispatch`` puts an
    ``(args, kwargs, out_ch)`` tuple on that channel and awaits one value from
    ``out_ch``. Nothing in the methods shown here consumes the input channels
    (``worker`` is still a stub). A ``dispatch`` call therefore waits until
    other code reads the request and puts a reply on ``out_ch``.

    The constructor also creates a shared ``dispatch_channel`` and a publisher
    that routes payloads by ``payload[0]``. ``register`` subscribes each input
    channel to that publisher under its function, and ``unregister`` removes
    the subscription again.
    """

    # func -> (input channel, options the function was registered with)
    dispatch_map: dict[Callable, tuple[ac.Chan, dict[str, Any]]]
    dispatch_channel: ac.Chan   # ac.Chan[DispatchPayload]
    publisher: ac.Pub   # ac.Pub[DispatchPayload]
    loop: asyncio.AbstractEventLoop
    parallel_executor_mode: ParallelExecutorMode

    def __init__(
        self,
        *,
        buffer_size: int = 100,
        parallel_executor_mode: ParallelExecutorMode = 'thread',
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        """Create the dispatcher and its shared channel/publisher.

        Args:
            buffer_size: Buffer size of ``dispatch_channel`` and of the
                publisher created from it.
            parallel_executor_mode: ``'thread'`` or ``'process'``. It is only
                stored on the instance; the methods here do not read it.
            loop: Event loop every channel is bound to. Defaults to the loop
                that is currently running. If no loop is running, a new one
                is created.
        """
        # Bind channels to the loop the caller's coroutines actually run on.
        # A fresh loop that nobody runs would own every channel future, and
        # awaiting a future from another loop fails inside a running task.
        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # Constructed from plain synchronous code: keep the previous
                # behaviour of creating a private loop.
                loop = asyncio.new_event_loop()
        self.loop = loop
        self.parallel_executor_mode = parallel_executor_mode

        # Set up the dispatch channel
        self.dispatch_map: dict[Callable, tuple[ac.Chan, dict[str, Any]]] = {}
        self.dispatch_channel = ac.Chan(
            buffer_size=buffer_size,
            loop=self.loop,
            name="dispatcher.publisher"
        )

        # Route each payload to subscribers of its first element (the target
        # function).
        self.publisher = self.dispatch_channel.pub(
            topic_fn=lambda payload: payload[0],
            buffer_size=buffer_size,
        )

    def close(self):
        """Close the publisher and the shared dispatch channel."""
        self.publisher.close()
        self.dispatch_channel.close()

    def unregister(self, func):
        """Unsubscribe and close ``func``'s input channel, then forget ``func``.

        Raises:
            AssertionError: if ``func`` is not registered. This is an
                ``assert``, so the check is skipped under ``python -O``.
        """
        assert func in self.dispatch_map, "Dispatch target not registered"

        # Get the input channel for this function
        in_ch, _ = self.dispatch_map[func]

        # Undo the subscription made in register(), then close the channel
        self.publisher.unsub(func, in_ch)
        in_ch.close()

        # Remove the function from the dispatch map
        del self.dispatch_map[func]

    def register(self, func, opts: Optional[DispatchOpts] = None) -> 'Dispatcher':
        """Create an input channel for ``func`` and subscribe it to the publisher.

        Args:
            func: The dispatch target. It is used as the ``dispatch_map`` key
                and as the publisher topic.
            opts: Options for this function, which may be partial. Missing
                keys are filled from ``default_dispatch_opts``.

        Returns:
            ``self``, so registrations can be chained.

        Raises:
            AssertionError: if ``func`` is already registered. This is an
                ``assert``, so the check is skipped under ``python -O``.
        """
        assert func not in self.dispatch_map, "Dispatch target already registered"

        # Merge into a fresh dict. Neither the shared module-level defaults
        # nor the caller's dict is stored by reference, so later mutation of
        # either cannot leak into this registration (or into all others).
        merged: dict[str, Any] = {**default_dispatch_opts, **(opts or {})}

        # Create the input channel for this function
        in_ch = ac.Chan(
            buffer_size=merged['buffer_size'],
            loop=self.loop,
        )

        # Mirror of the publisher.unsub(func, in_ch) performed in unregister()
        self.publisher.sub(func, in_ch)

        # Add the function to the dispatch map
        self.dispatch_map[func] = (in_ch, merged)

        return self

    async def dispatch(self, func, *args, **kwargs) -> Any:
        """Send a call request for ``func`` and await its reply.

        The request ``(args, kwargs, out_ch)`` is put on ``func``'s input
        channel, and the first value received on ``out_ch`` is returned.
        Under aiochan's convention, the return value is presumably ``None``
        if ``out_ch`` is closed before a reply arrives (verify).

        Raises:
            AssertionError: if ``func`` is not registered. This is an
                ``assert``, so the check is skipped under ``python -O``.
        """
        assert func in self.dispatch_map, "Dispatch target not registered"

        in_ch, _ = self.dispatch_map[func]
        out_ch = ac.Chan(1, loop=self.loop)
        try:
            await in_ch.put((args, kwargs, out_ch))
            return await out_ch.get()
        finally:
            # Close the reply channel even if the await is cancelled or raises
            out_ch.close()

    async def worker(self):
        """Not implemented yet."""
        ...
        