# runtime.batch_executor

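> A runtime that executes a net as a batch process: all input is fed in at the start, no input is accepted during execution, and the output is returned at the end.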

In [None]:
#| default_exp runtime.batch_executor

In [None]:
#| hide
from nbdev.showdoc import *; 

In [None]:
#| hide
# Run nbdev's export step (writes the `#|export` cells to the library module named by `default_exp`).
import nbdev; nbdev.nbdev_export()

In [None]:
#|export
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from types import MappingProxyType
from typing import Type, Tuple, Dict, List

import fbdev
from fbdev.exceptions import NodeError, EdgeError
from fbdev.comp.packet import Packet
from fbdev.comp.port import PortType, PortSpec, PortSpecCollection, PortID
from fbdev.comp.base_component import BaseComponent
from fbdev.graph.graph_spec import GraphSpec, NodeSpec
from fbdev.graph.packet_registry import TrackedPacket
from fbdev.graph.net import Edge, Node, BaseNode
from fbdev.graph.graph_component import GraphComponentFactory
from fbdev.runtime import BaseRuntime
from fbdev.runtime._utils import parse_args_into_port_packets, setup_packet_senders_and_receivers

In [None]:
#|hide
# Render the API documentation (signature + docstring) for BatchExecutor in the cell output.
show_doc(fbdev.runtime.batch_executor.BatchExecutor)

---

### BatchExecutor

>      BatchExecutor (node:BaseNode)

*Executes a net like a batch process (input fed in the beginning, and no input during the execution, and output is returned at the end).*

In [None]:
#|export
class BatchExecutor(BaseRuntime):
    """Executes a net like a batch process (input fed in the beginning, and no input during the execution, and output is returned at the end).

    The executor wraps a single node. `start` (blocking) and `astart` (awaitable)
    start the node, run all packet senders and receivers, stop the node, and
    return the collected output values.
    """
    def __init__(self, node:BaseNode):
        """Create an executor for `node`; the node is not started until `start`/`astart` is called."""
        super().__init__()
        self._node:BaseNode = node

    def _setup_execution(self, *args, config_vals=None, signals=None, ports_to_get=None, **kwargs):
        """Prepare a single batch run without executing it.

        Args:
            *args, **kwargs: Input values, forwarded to `parse_args_into_port_packets`.
            config_vals: Config values for the node; `None` means an empty dict.
            signals: Signals for the node; `None` means an empty set.
            ports_to_get: IDs of the ports to collect output from; `None` means
                every output port of the node.

        Returns:
            A `(coroutine, output_vals)` tuple. Awaiting the coroutine starts the
            node, runs all senders/receivers, and stops the node. `output_vals`
            is the container returned by `setup_packet_senders_and_receivers`
            (presumably filled in by the receivers while the coroutine runs).

        Raises:
            RuntimeError: If the node has already been started or stopped.
        """
        if self._node.states.started.get(): raise RuntimeError("Node has already started.")
        if self._node.states.stopped.get(): raise RuntimeError("Cannot run stopped node.")

        # Fresh containers per call: mutable defaults in the signature would be
        # shared between every invocation.
        if config_vals is None:
            config_vals = {}
        if signals is None:
            signals = set()
        if ports_to_get is None:
            ports_to_get = [port.id for port in self._node.ports.output.values()]

        input_vals, config_vals, signals = parse_args_into_port_packets(self._node.port_specs, config_vals, signals, *args, **kwargs)

        # NOTE(review): *args/**kwargs are forwarded a second time here even though
        # they were already parsed into input_vals above -- confirm this is intended.
        output_vals, _message_vals, input_senders, config_senders, output_receivers, message_receivers = \
            setup_packet_senders_and_receivers(self._node.ports, input_vals, config_vals, ports_to_get, *args, **kwargs)

        async def main():
            await self._node.start()
            await self._node.task_manager.exec_coros(*input_senders, *config_senders, *output_receivers, *message_receivers)
            await self._node.task_manager.exec_coros(self._node.stop())

        return main(), output_vals

    def start(self, *args, config:dict|None=None, signals:set|None=None, ports_to_get:List[PortID]|None=None, **kwargs):
        """Run the node to completion synchronously and return its output values.

        Note: this method cannot be run from within an event loop (it uses
        `asyncio.run`); use `astart` there instead.
        """
        super().start()
        coro, output = self._setup_execution(*args, config_vals=config, signals=signals, ports_to_get=ports_to_get, **kwargs)
        asyncio.run(coro)
        self._started = True
        return output

    async def astart(self, *args, config:dict|None=None, signals:set|None=None, ports_to_get:List[PortID]|None=None, **kwargs):
        """Awaitable counterpart of `start`: run the node to completion and return its output values."""
        await super().astart()
        coro, output = self._setup_execution(*args, config_vals=config, signals=signals, ports_to_get=ports_to_get, **kwargs)
        await coro
        self._started = True
        return output

    async def stop(self):
        """Stop the runtime, stopping the wrapped node first if it is not already stopped."""
        await super().stop()
        # A completed batch run has already stopped the node; only stop it here
        # when that did not happen.
        if not self._node.states.stopped.get():
            await self._node.task_manager.exec_coros(self._node.stop())
        self._stopped = True

In [None]:
class BaseFooComponent(BaseComponent):
    """Demo base component whose `_post_start` hook schedules `main()` on its task manager."""

    @abstractmethod
    async def main(self):
        """Component body; supplied by each concrete subclass."""

    async def _post_start(self):
        """Schedule `main()` as a task on this component's task manager."""
        self.task_manager.create_task(self.main())

class FooComponent1(BaseFooComponent):
    """Demo component: prints the payload received on `inp`, then emits 'there' on `out`."""

    port_specs = PortSpecCollection(
        PortSpec(PortType.INPUT, "inp"),
        PortSpec(PortType.OUTPUT, "out"),
    )

    async def main(self):
        """Consume one input packet, print its payload, and send a fixed reply."""
        incoming = await self.ports.input.inp.get()
        payload = await incoming.consume()
        print(payload)
        reply = Packet('there')
        await self.ports.output.out.put(reply)
        
class FooComponent2(BaseFooComponent):
    """Demo component: prints the payload received on `inp`, then emits 'world' on `out`."""

    port_specs = PortSpecCollection(
        PortSpec(PortType.INPUT, "inp"),
        PortSpec(PortType.OUTPUT, "out"),
    )

    async def main(self):
        """Consume one input packet, print its payload, and send a fixed reply."""
        incoming = await self.ports.input.inp.get()
        payload = await incoming.consume()
        print(payload)
        reply = Packet('world')
        await self.ports.output.out.put(reply)
        
# Start from a graph spec with no ports, then declare its external input and output ports.
graph = GraphSpec(PortSpecCollection())

graph.add_graph_port(PortSpec(PortType.INPUT, "inp"))
graph.add_graph_port(PortSpec(PortType.OUTPUT, "out"))

# Add one node per demo component class.
node1 = graph.add_node(FooComponent1)
node2 = graph.add_node(FooComponent2)

# Wire the pipeline: graph input -> node1 -> node2 -> graph output.
graph.ports.input.inp >> node1.ports.input.inp
node1 >> node2  # per the rendered diagram, this links node1's `out` to node2's `inp`
node2.ports.output.out >> graph.ports.output.out

# Render the graph as a Mermaid flowchart, leaving out ports that have no connection.
graph.display_mermaid(hide_unconnected_ports=True)

```mermaid
flowchart 
    subgraph FooComponent1["FooComponent1[]"]
        FooComponent1__C__input.inp[inp]
        FooComponent1__C__output.out[out]
    end
    subgraph FooComponent2["FooComponent2[]"]
        FooComponent2__C__input.inp[inp]
        FooComponent2__C__output.out[out]
    end
    GRAPH__C__message.started[[started]]
    GRAPH__C__message.stopped[[stopped]]
    GRAPH__C__input.inp[inp]
    GRAPH__C__output.out[out]
    GRAPH__C__input.inp -.-> FooComponent1__C__input.inp
    FooComponent2__C__output.out -.-> GRAPH__C__output.out
    FooComponent1__C__output.out --> FooComponent2__C__input.inp
    classDef input fill:#13543e;
    classDef output fill:#0d1b59;
    classDef subgraph_zone fill:#000;
    class FooComponent1__C__input.inp,FooComponent2__C__input.inp,GRAPH__C__input.inp input;
    class FooComponent1__C__output.out,FooComponent2__C__output.out,GRAPH__C__output.out output;
    class GRAPH__C__message.started,GRAPH__C__message.stopped message;
```

In [None]:
# Run the demo graph through a BatchExecutor: 5 is passed as the single positional
# input, and the returned output values are printed.
# NOTE: top-level `async with`/`await` only works in a notebook/IPython cell (or
# another context that already provides a running event loop).
async with BatchExecutor.from_graph(graph) as ex:
    res = await ex.astart(5)
    print(res)

5
there
{'out': 'world'}
