# graph.GraphComponentFactory

> Builds components whose behaviour is defined by a `GraphSpec` of nodes and edges.

In [None]:
#| default_exp graph.graph_component

In [None]:
#| hide
from nbdev.showdoc import *;

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#|export
from __future__ import annotations
import asyncio
from abc import abstractmethod
from types import MappingProxyType
from typing import Type, Tuple, Dict

import fbdev
from fbdev.comp.port import PortType, PortSpec, PortSpecCollection, PortID
from fbdev.comp.base_component import BaseComponent
from fbdev.graph.graph_spec import GraphSpec, NodeSpec
from fbdev.graph.packet_registry import PacketRegistry
from fbdev.graph.net import Edge, BaseNode
from fbdev.exceptions import NodeError, EdgeError

In [None]:
#|hide
from fbdev.comp.packet import Packet
from fbdev.graph.net import Node

In [None]:
#|hide
# Render the API documentation for GraphComponentFactory in the notebook.
show_doc(fbdev.graph.graph_component.GraphComponentFactory)

---

### GraphComponentFactory

>      GraphComponentFactory ()

*Helper class that provides a standard way to create an ABC using
inheritance.*

In [None]:
#|export
class GraphComponentFactory(BaseComponent, inherit_ports=False):
    """Component whose behaviour is defined by a `GraphSpec` of nodes and edges.

    Use `create_component` to obtain a component class bound to a read-only
    copy of a graph. When an instance starts, it instantiates one node per
    node spec and one `Edge` per edge spec of that graph, then starts them.
    When it stops, it stops them again. Failures reported by the child nodes'
    and edges' task managers are wrapped in `NodeError` / `EdgeError` and
    submitted to this component's own task manager.
    """
    is_factory = True
    expose_graph = True
    # Set (together with `expose_graph` and `port_specs`) through the
    # `class_attrs` passed by `create_component`.
    graph: GraphSpec = None

    # The bare factory declares no ports; `create_component` supplies the
    # graph's own port specs instead.
    port_specs = PortSpecCollection()

    def __init__(self):
        super().__init__()
        # Must be set by its node in BaseNode.start()
        self._parent_node: BaseNode = None
        # Children keyed by their spec id; filled in by `_post_start`.
        self._nodes: Dict[str, Node] = {}
        self._edges: Dict[str, Edge] = {}

    @property
    def nodes(self) -> MappingProxyType[str, Node]:
        """Read-only view of the child nodes, keyed by node spec id."""
        return MappingProxyType(self._nodes)

    @property
    def edges(self) -> MappingProxyType[str, Edge]:
        """Read-only view of the child edges, keyed by edge spec id."""
        return MappingProxyType(self._edges)

    @property
    def _packet_registry(self) -> PacketRegistry:
        # NOTE(review): `_parent_net` is never assigned in this class (only
        # `_parent_node` is); presumably BaseComponent provides it — confirm.
        return self._parent_net._packet_registry

    def _handle_node_exception(self, task:asyncio.Task, exceptions:Tuple[Exception, ...], source_trace:Tuple):
        """Wrap a child node's failure in `NodeError` and submit it upstream."""
        try:
            # Raising (rather than merely constructing) the wrapper chains the
            # first exception as `__cause__` and gives the wrapper a traceback.
            raise NodeError() from exceptions[0]
        except NodeError as wrapped:
            self.task_manager.submit_exception(task, exceptions + (wrapped,), source_trace)

    def _handle_edge_exception(self, task:asyncio.Task, exceptions:Tuple[Exception, ...], source_trace:Tuple):
        """Wrap a child edge's failure in `EdgeError` and submit it upstream."""
        try:
            raise EdgeError() from exceptions[0]
        except EdgeError as wrapped:
            self.task_manager.submit_exception(task, exceptions + (wrapped,), source_trace)

    @classmethod
    def create_component(cls, graph, expose_graph=True) -> Type[BaseComponent]:
        """Create a component class backed by a frozen copy of `graph`.

        The caller's graph is copied and the copy is made read-only, so later
        edits to the original do not affect the returned class. The copy's
        port specs become the component's ports.
        """
        frozen = graph.copy()
        frozen.make_readonly()
        attrs = {
            'graph': frozen,
            'expose_graph': expose_graph,
            'port_specs': frozen._port_specs,
        }
        return cls._create_component_class(class_attrs=attrs)

    async def _post_start(self):
        """Instantiate every child node and edge, then start them."""
        # Phase 1: build all children and subscribe to their failures.
        for node_spec in self.graph.nodes.values():
            child = node_spec.create_node(self._parent_node)
            self._nodes[node_spec.id] = child
            child.task_manager.subscribe(self._handle_node_exception)
        for edge_spec in self.graph.edges.values():
            link = Edge(edge_spec, self)
            self._edges[edge_spec.id] = link
            link.task_manager.subscribe(self._handle_edge_exception)

        # Phase 2: start nodes one at a time, then edges.
        for child in self._nodes.values():
            await child.start()
        for link in self.edges.values():
            # Not awaited here, whereas `_pre_stop` awaits `Edge.stop`;
            # presumably `Edge.start` is synchronous — confirm.
            link.start()

    async def _pre_stop(self):
        """Stop every child node, then every edge, one at a time."""
        for child in self._nodes.values():
            await child.stop()
        for link in self.edges.values():
            await link.stop()

In [None]:
class BaseFooComponent(BaseComponent):
    """Demo base component: `_post_start` schedules `main()` on the task manager."""

    @abstractmethod
    async def main(self):
        """Body of the demo component, implemented by subclasses."""

    async def _post_start(self):
        body = self.main()
        self.task_manager.create_task(body)

class FooComponent1(BaseFooComponent):
    """Demo component: prints one packet from `inp`, then emits 'there' on `out`."""
    port_specs = PortSpecCollection(
        PortSpec(PortType.INPUT, "inp"),
        PortSpec(PortType.OUTPUT, "out"),
    )

    async def main(self):
        incoming = await self.ports.input.inp.get()
        payload = await incoming.consume()
        print(payload)
        reply = Packet('there')
        await self.ports.output.out.put(reply)
        
class FooComponent2(BaseFooComponent):
    """Demo component: prints one packet from `inp`, then emits 'world' on `out`."""
    port_specs = PortSpecCollection(
        PortSpec(PortType.INPUT, "inp"),
        PortSpec(PortType.OUTPUT, "out"),
    )

    async def main(self):
        incoming = await self.ports.input.inp.get()
        payload = await incoming.consume()
        print(payload)
        reply = Packet('world')
        await self.ports.output.out.put(reply)
        
# Demo graph: graph `inp` -> FooComponent1 -> FooComponent2 -> graph `out`.
graph = GraphSpec(PortSpecCollection())

# Expose one input and one output port on the graph itself.
graph.add_graph_port(PortSpec(PortType.INPUT, "inp"))
graph.add_graph_port(PortSpec(PortType.OUTPUT, "out"))

node1 = graph.add_node(FooComponent1)
node2 = graph.add_node(FooComponent2)

# `>>` connects ports. `node1 >> node2` connects the nodes directly; the
# rendered diagram shows it linking node1's `out` to node2's `inp`.
graph.ports.input.inp >> node1.ports.input.inp
node1 >> node2
node2.ports.output.out >> graph.ports.output.out

graph.display_mermaid(hide_unconnected_ports=True)

```mermaid
flowchart 
    subgraph FooComponent1["FooComponent1[]"]
        FooComponent1__C__input.inp[inp]
        FooComponent1__C__output.out[out]
    end
    subgraph FooComponent2["FooComponent2[]"]
        FooComponent2__C__input.inp[inp]
        FooComponent2__C__output.out[out]
    end
    GRAPH__C__message.started[[started]]
    GRAPH__C__message.stopped[[stopped]]
    GRAPH__C__input.inp[inp]
    GRAPH__C__output.out[out]
    FooComponent1__C__output.out --> FooComponent2__C__input.inp
    FooComponent2__C__output.out -.-> GRAPH__C__output.out
    GRAPH__C__input.inp -.-> FooComponent1__C__input.inp
    classDef input fill:#13543e;
    classDef output fill:#0d1b59;
    classDef subgraph_zone fill:#000;
    class FooComponent1__C__input.inp,FooComponent2__C__input.inp,GRAPH__C__input.inp input;
    class FooComponent1__C__output.out,FooComponent2__C__output.out,GRAPH__C__output.out output;
    class GRAPH__C__message.started,GRAPH__C__message.stopped message;
```

In [None]:
# Turn the graph into a component class, wrap it in a node spec and
# instantiate the node. NOTE(review): this goes through the package-exported
# `fbdev.graph.GraphComponentFactory`, not the class defined in this session;
# presumably identical after `nbdev_export()` — confirm.
graph_component = fbdev.graph.GraphComponentFactory.create_component(graph)
net_spec = NodeSpec(graph_component)
net = Node(net_spec)

packet = Packet('hello')

async def get_output():
    """Wait for one packet on the net's `out` port and print its consumed value."""
    out_port = net.ports[(PortType.OUTPUT, 'out')]
    received = await out_port.get()
    print(await received.consume())

# Start the net, feed 'hello' into `inp` and read the result from `out`.
# Both coroutines are handed to the net's task manager together (presumably
# run concurrently, since the read must wait for the write — confirm).
await net.start()
await net.task_manager.exec_coros(
    net.ports[(PortType.INPUT, 'inp')].put(packet),
    get_output(),
)
await net.stop()

hello
there
world
