# runtime.batch_executor

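> Runtime that executes a `Net` as a batch process: all inputs are fed in at the start, and the outputs are returned once execution finishes.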

In [None]:
#| default_exp runtime.batch_executor

In [None]:
#| hide
from nbdev.showdoc import *; 

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#|export
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from types import MappingProxyType
from typing import Type, Tuple, Dict

import fbdev
from fbdev.exceptions import NodeError, EdgeError
from fbdev.comp.packet import Packet
from fbdev.comp.port import PortType, PortSpec, PortSpecCollection, PortID
from fbdev.comp.base_component import BaseComponent
from fbdev.graph.graph_spec import GraphSpec, NodeSpec
from fbdev.graph.packet_registry import TrackedPacket
from fbdev.graph.net import Edge, Node, Net
from fbdev.graph.graph_component import GraphComponentFactory
from fbdev.runtime import BaseNetRuntime

In [None]:
#|hide
show_doc(fbdev.runtime.batch_executor.BatchExecutor)

---

### BatchExecutor

>      BatchExecutor (net:Net)

*Executes a net like a batch process (input fed in the beginning, and no input during the execution, and output is returned at the end).*

In [None]:
#|export
class BatchExecutor(BaseNetRuntime):
    """Executes a net like a batch process: all inputs (and configs) are fed in at the beginning,
    no further input is provided during execution, and the outputs are returned at the end.

    The wrapped net can be executed at most once. `stop` terminates the net (if it has not
    terminated yet) and detaches it, after which the executor can no longer be used.
    """
    def __init__(self, net:Net):
        super().__init__()
        self._net:Net = net

    def _get_net(self) -> Net:
        """Return the wrapped net, raising `RuntimeError` if `stop` has already detached it."""
        if self._net is None: raise RuntimeError("BatchExecutor has been stopped.")
        return self._net

    @staticmethod
    def _collect_input_vals(net:Net, args:tuple, kwargs:dict) -> dict:
        """Merge positional and keyword input values into one dict keyed by input-port name, validating them."""
        input_names = list(net.port_specs.input.keys())
        # `zip` below would silently drop surplus positional values, so reject them explicitly.
        if len(args) > len(input_names):
            raise ValueError(f"Too many positional input values: expected at most {len(input_names)}, got {len(args)}.")
        input_vals = dict(kwargs)
        for port_name, val in zip(input_names, args):
            if port_name in input_vals: raise ValueError(f"Multiple values provided for '{port_name}'.")
            input_vals[port_name] = val
        missing_input_args = set(input_names) - set(input_vals)
        if missing_input_args:
            raise ValueError(f"Missing values for ports '{missing_input_args}'.")
        extra_input_args = set(input_vals) - set(input_names)
        if extra_input_args:
            raise ValueError(f"Unexpected values for inputs '{extra_input_args}'.")
        return input_vals

    @staticmethod
    def _check_config_vals(net:Net, config_vals) -> None:
        """Ensure `config_vals` provides a value for every config port of `net`, and nothing else."""
        config_names = set(net.port_specs.config.keys())
        extra_config_args = set(config_vals.keys()) - config_names
        if extra_config_args:
            raise ValueError(f"Unexpected values for configs '{extra_config_args}'.")
        # Every config port is sent a packet, so a missing value would otherwise surface later as a bare KeyError.
        missing_config_args = config_names - set(config_vals.keys())
        if missing_config_args:
            raise ValueError(f"Missing values for configs '{missing_config_args}'.")

    def _setup_execution_coro(self, *args, config_vals=None, **kwargs):
        """Validate the provided values and build the coroutine that runs the net once.

        Positional `args` are matched to the net's input ports in declaration order; `kwargs`
        supply input values by port name; `config_vals` maps config-port names to values.

        Returns:
            A `(coroutine, output_vals)` pair. `output_vals` is a dict that is filled in while the
            coroutine runs, keyed by `port_id[1]` of each output port (the port name in the example below).

        Raises:
            RuntimeError: if the executor was stopped, or the net has already started or terminated.
            ValueError: if the input or config values do not match the net's ports.
        """
        net = self._get_net()
        config_vals = {} if config_vals is None else config_vals

        if net.states.started.get(): raise RuntimeError("Net has already started.")
        if net.states.terminated.get(): raise RuntimeError("Cannot run terminated Net.")

        input_vals = self._collect_input_vals(net, args, kwargs)
        self._check_config_vals(net, config_vals)

        output_vals = {}

        async def packet_sender(port_id: PortID, val):
            await net.component_process.ports[port_id]._put(Packet(val))

        async def packet_receiver(port_id: PortID):
            packet = await net.component_process.ports[port_id]._get()
            output_vals[port_id[1]] = await packet.consume()

        async def main():
            # The per-port coroutines are created here rather than during setup, so none of them is
            # left un-awaited ("coroutine was never awaited") if the returned coroutine is never run.
            senders = [packet_sender(port.id, input_vals[port.name]) for port in net.ports.input.values()]
            senders += [packet_sender(port.id, config_vals[port.name]) for port in net.ports.config.values()]
            receivers = [packet_receiver(port.id) for port in net.ports.output.values()]
            await net.start()
            await net.exec_coros(*senders, *receivers)
            await net.exec_coros(net.terminate())

        return main(), output_vals

    def execute(self, *args, config=None, **kwargs):
        """Run the net to completion synchronously and return its outputs.

        Positional and keyword arguments are input values; `config` maps config-port names to
        values. Because `config` is a keyword of this method, an input port named "config" can
        only be supplied positionally.

        Note: this method cannot be run from within an event loop; use `aexecute` there.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop: it is safe to start one with asyncio.run.
        else:
            # Fail before building the execution coroutine, so nothing is left un-awaited.
            raise RuntimeError("BatchExecutor.execute cannot be called from a running event loop; use aexecute instead.")
        exec_coro, output = self._setup_execution_coro(*args, config_vals=config, **kwargs)
        asyncio.run(exec_coro)
        return output

    async def aexecute(self, *args, config=None, **kwargs):
        """Asynchronous counterpart of `execute`: run the net in the current event loop and return its outputs."""
        exec_coro, output = self._setup_execution_coro(*args, config_vals=config, **kwargs)
        await exec_coro
        return output

    async def stop(self):
        """Terminate the net if it has not terminated yet, then detach it from this executor.

        Calling `stop` again afterwards is a no-op.
        """
        if self._net is not None:
            if not self._net.states.terminated.get():
                await self._net.terminate()
            self._net = None

In [None]:
class BaseFooComponent(BaseComponent):
    """Example base component whose `_post_start` hook schedules the subclass's `main` coroutine."""

    @abstractmethod
    async def main(self):
        """Body of the component; each concrete subclass implements it."""

    async def _post_start(self):
        # Presumably called by BaseComponent once the component has started — hand main() to its task manager.
        self._task_manager.create_task(self.main())

class FooComponent1(BaseFooComponent):
    """Example component: prints the packet received on `inp`, then emits 'there' on `out`."""
    port_specs = PortSpecCollection(
        PortSpec(PortType.INPUT, "inp"),
        PortSpec(PortType.OUTPUT, "out"),
    )

    async def main(self):
        incoming = await self.ports.input.inp.get()
        payload = await incoming.consume()
        print(payload)
        reply = Packet('there')
        await self.ports.output.out.put(reply)
        
class FooComponent2(BaseFooComponent):
    """Example component: prints the packet received on `inp`, then emits 'world' on `out`."""
    port_specs = PortSpecCollection(
        PortSpec(PortType.INPUT, "inp"),
        PortSpec(PortType.OUTPUT, "out"),
    )

    async def main(self):
        incoming = await self.ports.input.inp.get()
        payload = await incoming.consume()
        print(payload)
        reply = Packet('world')
        await self.ports.output.out.put(reply)
        
# Example graph: graph input `inp` -> FooComponent1 -> FooComponent2 -> graph output `out`.
graph = GraphSpec(PortSpecCollection())

# Expose one input and one output port on the graph itself.
graph.add_graph_port(PortSpec(PortType.INPUT, "inp"))
graph.add_graph_port(PortSpec(PortType.OUTPUT, "out"))

node1 = graph.add_node(FooComponent1)
node2 = graph.add_node(FooComponent2)

# `>>` creates an edge from the left port/node to the right one.
graph.ports.input.inp >> node1.ports.input.inp
# Node-to-node shorthand; the rendered diagram shows it links FooComponent1.out to FooComponent2.inp.
node1 >> node2
node2.ports.output.out >> graph.ports.output.out

# Render the graph as a Mermaid flowchart (presumably omitting ports that have no edges).
graph.display_mermaid(hide_unconnected_ports=True)

```mermaid
flowchart 
    subgraph FooComponent1["FooComponent1[]"]
        FooComponent1__C__input__D__inp[inp]
        FooComponent1__C__output__D__out[out]
    end
    subgraph FooComponent2["FooComponent2[]"]
        FooComponent2__C__input__D__inp[inp]
        FooComponent2__C__output__D__out[out]
    end
    GRAPH__C__input__D__inp[inp]
    GRAPH__C__output__D__out[out]
    GRAPH__C__input__D__inp -.-> FooComponent1__C__input__D__inp
    FooComponent1__C__output__D__out --> FooComponent2__C__input__D__inp
    FooComponent2__C__output__D__out -.-> GRAPH__C__output__D__out
    classDef input fill:#13543e;
    classDef output fill:#0d1b59;
    classDef subgraph_zone fill:#000;
    class FooComponent1__C__input__D__inp,FooComponent2__C__input__D__inp,GRAPH__C__input__D__inp input;
    class FooComponent1__C__output__D__out,FooComponent2__C__output__D__out,GRAPH__C__output__D__out output;
```

In [None]:
# Run the example graph once, feeding 5 positionally to its single input port `inp`.
# `from_graph` and the async context-manager support come from BaseNetRuntime (not shown here);
# presumably exiting the block calls `stop()` — TODO confirm.
async with BatchExecutor.from_graph(graph) as ex:
    res = await ex.aexecute(5)
    print(res)

5
there
{'out': 'world'}
