# Execute

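> Runtimes for executing flow-based nets built from a `Graph`: an abstract `NetRuntime` interface and a `BatchExecutor` that feeds all inputs up front and returns all outputs at the end.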

In [None]:
#| default_exp runtime

In [None]:
#| hide
from nbdev.showdoc import *; 

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#|export
import asyncio
import warnings
from collections import deque
from types import MappingProxyType
from typing import Type, Optional, Union, Any, Tuple, Dict
from abc import ABC, abstractmethod

import fbdev
from fbdev.packet import Packet
from fbdev.port import PortType, PortSpec, ConfigPortSpec, PortTypeSpec, PortSpecCollection, BasePort, InputPort, ConfigPort, OutputPort, PortCollection
from fbdev.component import BaseComponent, ComponentFactory, PortSpec, BasePort, InputPort, OutputPort, PortCollection, PortType
from fbdev.graph import EdgeSpec, NodeSpec, Graph, ReadonlyGraph
from fbdev.exceptions import ComponentError, NodeError, LostPacketError
from fbdev.node import Node, GraphComponentFactory

In [None]:
#|hide
from fbdev.component import func_component

In [None]:
#|hide
show_doc(fbdev.runtime.NetRuntime)  # render NetRuntime's API documentation in the notebook

---

### NetRuntime

>      NetRuntime ()

*Abstract base class for runtimes that execute a flow-based net built from a `Graph`.*

In [None]:
#|export
class NetRuntime(ABC):
    """Abstract base class for runtimes that execute a flow-based net.

    A runtime wraps a single net described by a `NodeSpec` (see `from_graph`).
    Subclasses implement `execute` / `async_execute` to run the net with the
    given input values and config values, and `destroy` to release the net.
    Instances can be used as sync or async context managers; `destroy` is
    called on exit.
    """

    @classmethod
    def from_graph(cls, graph:Graph):
        """Create a runtime for `graph`.

        The graph is turned into a component type via
        `GraphComponentFactory.get_component`, wrapped in a `NodeSpec`, and
        passed to the subclass constructor (which must accept a `NodeSpec`).
        """
        component_type = GraphComponentFactory.get_component(graph)
        net_spec = NodeSpec(component_type)
        return cls(net_spec)

    @classmethod
    def execute_graph(cls, graph:Graph, *args, config=None, **kwargs):
        """Build a runtime for `graph`, execute it once, destroy it, and return the result.

        Args:
            graph: the graph to execute.
            *args, **kwargs: input values for the net's input ports.
            config: mapping of config port names to values (defaults to empty).

        Must not be called from inside a running event loop: leaving the sync
        context manager drives `destroy` with `asyncio.run`.
        """
        # `None` default avoids sharing one mutable `{}` across calls.
        config = {} if config is None else config
        with cls.from_graph(graph) as netrun:
            # `config` must be passed by keyword; passing it positionally would
            # make it an extra *input* value instead of the config mapping.
            return netrun.execute(*args, config=config, **kwargs)

    @classmethod
    async def async_execute_graph(cls, graph:Graph, *args, config=None, **kwargs):
        """Async counterpart of `execute_graph`; usable from inside an event loop."""
        config = {} if config is None else config
        async with cls.from_graph(graph) as netrun:
            return await netrun.async_execute(*args, config=config, **kwargs)

    @abstractmethod
    def execute(self, *args, config=None, **kwargs):
        """Run the net synchronously with the given inputs and config; return its outputs."""
        raise NotImplementedError()

    @abstractmethod
    async def async_execute(self, *args, config=None, **kwargs):
        """Run the net inside the current event loop; return its outputs."""
        raise NotImplementedError()

    @abstractmethod
    async def destroy(self):
        """Release the net owned by this runtime."""
        raise NotImplementedError()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.destroy()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # `destroy` is a coroutine, so it is run to completion on a fresh event
        # loop; this raises if an event loop is already running in this thread.
        asyncio.run(self.destroy())

In [None]:
#|hide
show_doc(fbdev.runtime.BatchExecutor)  # render BatchExecutor's API documentation in the notebook

---

### BatchExecutor

>      BatchExecutor (net_spec:fbdev.graph.NodeSpec)

*Executes a net like a batch process (input fed in the beginning, and no input during the execution, and output is returned at the end).*

In [None]:
#|export
class BatchExecutor(NetRuntime):
    """Executes a net like a batch process (input fed in the beginning, and no input during the execution, and output is returned at the end)."""
    def __init__(self, net_spec:NodeSpec):
        """Create the net described by `net_spec` as a top-level node (no parent graph process)."""
        super().__init__()
        self._net_spec:NodeSpec = net_spec
        # Set to None by `destroy()`.
        self._net:Optional[Node] = Node(node_spec=self._net_spec, parent_graph_process=None)

    def _setup_execution_coro(self, *args, config=None, **kwargs):
        """Validate the call arguments and build the coroutine that runs the net once.

        Positional `args` are matched, in order, to the net's input ports and
        `kwargs` name input ports directly; every input port must receive
        exactly one value.

        Returns:
            A `(coro, output)` pair. `output` is a dict that `coro` fills with
            one entry per output port when it is awaited.

        Raises:
            RuntimeError: if the executor was destroyed, the net is already
                running, or the net is still not initialised after `initialise()`.
            ValueError: on too many positional values, or duplicate, missing or
                unexpected input values, or unexpected config names.
        """
        if config is None: config = {}
        net = self._net
        if net is None: raise RuntimeError("Net has been destroyed.")
        if not net.states.initialised.get(): net.initialise()

        if net.states.running.get(): raise RuntimeError("Net is already running.")
        if not net.states.initialised.get(): raise RuntimeError("Net is not initialised.")

        input_port_names = list(net.port_specs.input.keys())

        # Fill in input args. `zip` below would silently drop surplus positional values.
        if len(args) > len(input_port_names):
            raise ValueError(f"Expected at most {len(input_port_names)} positional input values, got {len(args)}.")
        inputs = {**kwargs}
        for port_name, val in zip(input_port_names, args):
            if port_name in inputs:
                raise ValueError(f"Multiple values provided for '{port_name}'.")
            inputs[port_name] = val
        missing_input_args = set(input_port_names) - set(inputs.keys())
        if len(missing_input_args) > 0:
            raise ValueError(f"Missing values for ports {missing_input_args}.")

        # Check for unexpected input args
        extra_input_args = set(inputs.keys()) - set(input_port_names)
        if len(extra_input_args) > 0:
            raise ValueError(f"Unexpected values for inputs {extra_input_args}.")

        # Check for unexpected config args
        # NOTE(review): `config` is only validated here; its values are never sent
        # to the net below. Confirm whether they should be forwarded to the config ports.
        extra_config_args = set(config.keys()) - set(net.port_specs.config.keys())
        if len(extra_config_args) > 0:
            raise ValueError(f"Unexpected values for configs {extra_config_args}.")

        async def input_sender():
            for port_name, val in inputs.items():
                await net.send_input(PortType.INPUT, port_name, val)

        output = {}
        async def output_receiver():
            for port_name in net.port_specs.output.keys():
                data = await net.receive_output(PortType.OUTPUT, port_name)
                output[port_name] = data

        async def main():
            net.run()
            # Feed inputs and collect outputs concurrently, then stop the net.
            await net._execute_with_exception_monitoring(input_sender(), output_receiver())
            await net._execute_with_exception_monitoring(net.async_stop())

        return main(), output

    def execute(self, *args, config=None, **kwargs):
        """Run the net once and return a dict of output port name -> value.

        Note: this method cannot be run from within an event loop (it uses `asyncio.run`).
        """
        exec_coro, output = self._setup_execution_coro(*args, config=config, **kwargs)
        asyncio.run(exec_coro)
        return output

    async def async_execute(self, *args, config=None, **kwargs):
        """Run the net once in the current event loop and return a dict of output port name -> value."""
        exec_coro, output = self._setup_execution_coro(*args, config=config, **kwargs)
        await exec_coro
        return output

    async def destroy(self):
        """Destroy the underlying net; calling it again is a no-op."""
        if self._net is not None:
            await self._net.destroy()
            self._net = None

    def __del__(self):
        # `getattr`: `__init__` may have raised before `_net` was assigned.
        if getattr(self, '_net', None) is not None:
            # `destroy` is async and cannot be awaited from a finalizer, so only warn.
            warnings.warn(
                f"{type(self).__name__} was garbage-collected without `destroy()` being called.",
                ResourceWarning,
                source=self,
            )

In [None]:
# Demo: build a small graph by hand and run it once with `BatchExecutor`.
# Data flow: graph input `in1` -> add_one -> copier, whose two outputs feed
# printer (`out0`) and sender (`out1`); sender's output becomes graph output `out`.

@func_component()
def add_one(a:int) -> int:
    print("In add_one")
    return a+1

@func_component()
def copier(a:int) -> Tuple[int, int]:
    # Returns a 2-tuple; it is wired below as two output ports, `out0` and `out1`.
    print("In copier")
    return a, a

@func_component()
def printer(a:int) -> None:
    # Sink node: nothing is connected to its output.
    print("In printer1:", a)
    
@func_component()
def sender(a:int):
    # No return annotation, yet its output port is wired as `out` below.
    print("In sender")
    return a
    
# The graph's own ports: one int input `in1` and one int output `out`.
g = Graph(PortSpecCollection(
    input=PortTypeSpec(in1=PortSpec(dtype=int)),
    output=PortTypeSpec(out=PortSpec(dtype=int))
))

g.add_node(NodeSpec(add_one))
g.add_node(NodeSpec(copier))
g.add_node(NodeSpec(printer))
g.add_node(NodeSpec(sender))

# Five edges, referred to below by the indices 0-4
# (presumably the order in which they were added).
g.add_edge(EdgeSpec())
g.add_edge(EdgeSpec())
g.add_edge(EdgeSpec())
g.add_edge(EdgeSpec())
g.add_edge(EdgeSpec())

# Edge 0: graph input `in1` -> add_one.a
g.connect_edge_to_graph_port(PortType.INPUT, 'in1', 0)
g.connect_edge_to_node('add_one', PortType.INPUT, 'a', 0)

# Edge 1: add_one.out -> copier.a
g.connect_edge_to_node('add_one', PortType.OUTPUT, 'out', 1)
g.connect_edge_to_node('copier', PortType.INPUT, 'a', 1)

# Edge 2: copier.out0 -> printer.a
g.connect_edge_to_node('copier', PortType.OUTPUT, 'out0', 2)
g.connect_edge_to_node('printer', PortType.INPUT, 'a', 2)

# Edge 3: copier.out1 -> sender.a
g.connect_edge_to_node('copier', PortType.OUTPUT, 'out1', 3)
g.connect_edge_to_node('sender', PortType.INPUT, 'a', 3)

# Edge 4: sender.out -> graph output `out`
g.connect_edge_to_node('sender', PortType.OUTPUT, 'out', 4)
g.connect_edge_to_graph_port(PortType.OUTPUT, 'out', 4)

# The positional `1` is the value for the graph's single input `in1`.
# Top-level `await` only works in a notebook/IPython session.
output = await BatchExecutor.async_execute_graph(g, 1)
print("Output:", output['out'])

In add_one
In copier
In printer1: 2
In sender
Output: 2
