# component.port

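> Port types (`PortType`), port specifications (`PortSpec`, `PortSpecCollection`), and asyncio-based ports (`Port`, `PortCollection`) that pass packets between putters and getters.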

In [None]:
#| default_exp comp.port

In [None]:
#| hide
from nbdev.showdoc import *; 

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#|export
from __future__ import annotations
import asyncio
from enum import Enum
from typing import List, Dict, Callable, Any, Tuple, NewType, Iterator
from types import MappingProxyType
from abc import ABC, abstractmethod

import fbdev
from fbdev.comp.packet import BasePacket, Packet
from fbdev._utils import SingletonMeta, AttrContainer, StateHandler, StateCollection, is_valid_name

In [None]:
#|hide
show_doc(fbdev.comp.port.PortType)

---

### PortType

>      PortType (value, names=None, module=None, qualname=None, type=None,
>                start=1)

*An enumeration.*

In [None]:
#|export
class PortType(Enum):
    """Kinds of component ports.

    Each member's value is a ``(label, is_input_port)`` pair: `label` is the lowercase name used
    for attribute groups and string ids, and `is_input_port` says whether data flows into the port.
    """
    INPUT = ("input", True)
    CONFIG = ("config", True)
    SIGNAL = ("signal", True)
    
    OUTPUT = ("output", False)
    MESSAGE = ("message", False)
    
    def __init__(self, label:str, is_input_port:bool):
        self._label:str = label
        self._is_input_port:bool = is_input_port
        
    @property
    def label(self) -> str: return self._label
    @property
    def is_input_port(self) -> bool: return self._is_input_port
    
    @classmethod
    def get(cls, port_type_label:str) -> PortType:
        """Return the member whose `label` equals `port_type_label`.

        A classmethod so it can be called as ``PortType.get('input')`` (or on any member); the
        enum *class* is iterable, whereas individual members are not.

        Raises:
            RuntimeError: if no member has that label.
        """
        for port_type in cls:
            if port_type.label == port_type_label:
                return port_type
        raise RuntimeError(f"Port type {port_type_label} does not exist.")

In [None]:
#|export
# Key identifying a port within a collection: the pair (port type, port name), e.g. (PortType.INPUT, 'in1').
PortID = NewType('PortID', Tuple[PortType, str])

In [None]:
#|hide
show_doc(fbdev.comp.port.PortSpec)

---

### PortSpec

>      PortSpec (port_type, name=None, dtype=None, data_validator=None,
>                is_optional=False, default=<fbdev._utils.NO_DEFAULT object at
>                0x1209b60d0>)

*Initialize self.  See help(type(self)) for accurate signature.*

In [None]:
#|export
class PortSpec:
    """Declarative description of a single component port.

    Records the port's type, its (possibly dotted) name, an optional `dtype` and `data_validator`,
    and — for config ports only — whether the port is optional or has a default value.

    Raises (from `__init__`):
        ValueError: if `dtype` is given but is not a class.
        RuntimeError: if a signal port is given a dtype or a data validator, if a non-config port
            is made optional or given a default, or if a config port is both optional and has a
            default.
    """
    # Class whose instances act as the "no default supplied" sentinel; `has_default` compares
    # against it by type.
    _NO_DEFAULT = SingletonMeta('NO_DEFAULT')
    
    def __init__(self, port_type, name=None, dtype=None, data_validator=None, is_optional=False, default=_NO_DEFAULT()):
        self._name:str = name
        self._port_type:PortType = port_type
        self._dtype:type = dtype
        self._data_validator:Callable[[Any], bool] = data_validator
        self._is_optional = is_optional
        self._default = default
        
        # `isinstance` (rather than `type(dtype) == type`) also accepts classes that use a custom
        # metaclass, such as ABC subclasses or Enum classes.
        if dtype is not None and not isinstance(dtype, type):
            raise ValueError("Argument `dtype` must be a type")
        
        if port_type == PortType.SIGNAL:
            if dtype is not None: raise RuntimeError(f"Signal port {self.name} cannot have a dtype.")
            if data_validator is not None: raise RuntimeError(f"Signal port {self.name} cannot have a data validator.")
        
        if port_type != PortType.CONFIG:
            if is_optional:
                raise RuntimeError(f"Only ports of type {PortType.CONFIG} can be optional.")
            if self.has_default:
                raise RuntimeError(f"Only ports of type {PortType.CONFIG} can have a default value.")
        
        if self.is_optional and self.has_default:
            raise RuntimeError(f"Config port {self.name} cannot both be optional and have a default value.")
            
    @property
    def name(self) -> str: return self._name
    @property
    def id(self) -> PortID: return (self._port_type, self._name)
    @property
    def id_str(self) -> str: return f"{self._port_type.label}.{self._name}"
    @property
    def port_type(self) -> PortType: return self._port_type
    @property
    def is_input_port(self) -> bool: return self._port_type.is_input_port
    @property
    def is_output_port(self) -> bool: return not self.is_input_port
    @property
    def dtype(self) -> type: return self._dtype
    @property
    def data_validator(self) -> Callable[[Any], bool]: return self._data_validator
    
    @property
    def has_dtype(self) -> bool: return self._dtype is not None
    @property
    def has_data_validator(self) -> bool: return self._data_validator is not None
    
    @property
    def is_optional(self) -> bool: return self._is_optional
    @property
    def default(self) -> Any:
        """The default value; raises RuntimeError if this port has none."""
        if not self.has_default: raise RuntimeError(f"Config port {self.name} does not have a default value.")
        return self._default
    @property
    def has_default(self) -> bool: return type(self._default) != PortSpec._NO_DEFAULT
    
    def __str__(self) -> str:
        return self.id_str
    
    def __repr__(self) -> str:
        return str(self)

    def copy(self) -> PortSpec:
        """Return a new `PortSpec` with the same settings (dtype and validator objects are shared)."""
        # `self._default` holds either the real default or the no-default sentinel; passing the
        # sentinel through reproduces "no default" in the copy, so one call covers both cases.
        return PortSpec(
            self._port_type,
            name=self._name,
            dtype=self._dtype,
            data_validator=self._data_validator,
            is_optional=self._is_optional,
            default=self._default,
        )

In [None]:
#|hide
show_doc(fbdev.comp.port.PortSpecCollection)

---

### PortSpecCollection

>      PortSpecCollection (*port_specs:List[PortSpec])

*Initialize self.  See help(type(self)) for accurate signature.*

In [None]:
#|export
class PortSpecCollection:
    """A set of `PortSpec`s addressable by id and by attribute.

    Specs are stored in a dict keyed by `PortSpec.id` (a ``(PortType, name)`` pair). They are also
    exposed as attributes grouped by port-type label, e.g. ``coll.input.in1``. A dotted name such
    as ``'group.in1'`` is nested as ``coll.input.group.in1``. Call `make_readonly` to freeze the
    collection.
    """
    def __init__(self, *port_specs:PortSpec):
        self._readonly:bool = False
        self._ports: Dict[PortID, PortSpec] = {}
        # One top-level attribute container per port type (e.g. `self.input`, `self.output`).
        for port_type in PortType:
            setattr(self, port_type.label, AttrContainer({}, obj_name=f"{PortSpecCollection.__name__}.{port_type.label}"))
        for port_spec in port_specs:
            # Type check first so non-PortSpec arguments get a TypeError instead of an
            # AttributeError on `.name`.
            if not isinstance(port_spec, PortSpec): raise TypeError(f"PortSpecCollection can only contain PortSpecs. Got '{type(port_spec)}'.")
            if port_spec.name is None: raise ValueError("PortSpec.name is None.")
            self.add_port(port_spec)
    
    def __getitem__(self, key:PortID) -> PortSpec:
        if key in self._ports: return self._ports[key]
        else: raise KeyError(f"'{key}' does not exist in {self.__class__.__name__}.")
    
    def __iter__(self): return self._ports.__iter__()
    def __len__(self): return self._ports.__len__()
    def __contains__(self, key): return key in self._ports
    def as_dict(self) -> MappingProxyType: return MappingProxyType(self._ports)
    def iter_ports(self) -> Iterator[PortSpec]: return self._ports.values()
    
    def make_readonly(self): self._readonly = True
    
    def add_port(self, port_spec:PortSpec):
        """Add `port_spec` and expose it as a (possibly nested) attribute.

        Intermediate groups for a dotted name are created as needed.

        Raises:
            RuntimeError: if the collection is readonly.
            ValueError: if the name is invalid, already exists for this port type, or clashes with
                an existing port or group at some level of its dotted path.
        """
        if self._readonly: raise RuntimeError("Cannot add ports to a readonly PortSpecCollection.")
        if not is_valid_name(port_spec.name): raise ValueError(f"Invalid port name '{port_spec.name}'.")
        if port_spec.id in self._ports: raise ValueError(f"Port name '{port_spec.name}' already exists in {self.__class__.__name__}.")
        
        *name_parts, name_stem = port_spec.name.split('.')
        root = getattr(self, port_spec.port_type.label)
        
        # Validate the whole path before mutating anything, so a clash leaves the collection unchanged.
        attr_container = root
        for depth, name_part in enumerate(name_parts):
            if name_part not in attr_container:
                break  # this group (and everything below it) will be freshly created: no clash possible
            attr_container = attr_container[name_part]
            if not isinstance(attr_container, AttrContainer):
                prefix = '.'.join(name_parts[:depth + 1])
                raise ValueError(f"Port name '{port_spec.name}' conflicts with existing port '{prefix}' in {self.__class__.__name__}.")
        else:
            # The full group path exists; its stem must not already name a sub-group.
            if name_stem in attr_container:
                raise ValueError(f"Port name '{port_spec.name}' conflicts with an existing port group in {self.__class__.__name__}.")
        
        self._ports[port_spec.id] = port_spec
        attr_container = root
        attr_container_addr = f"{PortSpecCollection.__name__}.{port_spec.port_type.label}"
        for name_part in name_parts:
            attr_container_addr += f".{name_part}"
            if name_part not in attr_container:
                attr_container._set(name_part, AttrContainer({}, obj_name=attr_container_addr))
            attr_container = attr_container[name_part]
        attr_container._set(name_stem, port_spec)
    
    def remove_port(self, port_spec:PortSpec):
        """Remove `port_spec`, including its nested attribute entry.

        Groups left empty by the removal are pruned.

        Raises:
            RuntimeError: if the collection is readonly.
            ValueError: if no port with this id is present.
        """
        if self._readonly: raise RuntimeError("Cannot remove ports from a readonly PortSpecCollection.")
        if port_spec.id not in self._ports: raise ValueError(f"Port name '{port_spec.name}' does not exist in {self.__class__.__name__}.")
        
        *name_parts, name_stem = port_spec.name.split('.')
        # Walk down to the container that actually holds the port; dotted names are nested.
        containers = [getattr(self, port_spec.port_type.label)]
        for name_part in name_parts:
            containers.append(containers[-1][name_part])
        containers[-1]._remove(name_stem)
        # Prune groups that became empty, innermost first, mirroring how `add_port` created them.
        for parent, name_part in zip(reversed(containers[:-1]), reversed(name_parts)):
            if len(parent[name_part]) > 0: break
            parent._remove(name_part)
        del self._ports[port_spec.id]
        
    def update(self, parent:PortSpecCollection):
        """Add every port of `parent` (the same `PortSpec` objects, not copies).

        Raises:
            RuntimeError: if the collection is readonly.
            ValueError: if any port of `parent` already exists here; nothing is added in that case.
        """
        if self._readonly: raise RuntimeError("Cannot add ports to a readonly PortSpecCollection.")
        # Check duplicates up front so a clash does not leave `self` partially updated.
        for port in parent._ports.values():
            if port.id in self._ports: raise ValueError(f"Port name '{port.name}' already exists in {self.__class__.__name__}.")
        for port in parent._ports.values():
            self.add_port(port)
        
    def copy(self) -> PortSpecCollection:
        """Return a collection holding copies of every spec. Note: The copy is not readonly."""
        port_spec_collection = PortSpecCollection(
            *[port_spec.copy() for port_spec in self._ports.values()]
        )
        return port_spec_collection
        
    def __str__helper(self, attr_container:AttrContainer, lines:List[str], indent:str=''):
        # Recursively append one line per port/group, indenting nested groups by two spaces.
        for key, value in attr_container.items():
            if isinstance(value, AttrContainer):
                lines.append(f"{indent}{key}:")
                self.__str__helper(value, lines, indent + "  ")
            else: lines.append(f"{indent}{key}")
        
    def __str__(self) -> str:
        lines = []
        for port_type in PortType:
            if len(getattr(self, port_type.label)) == 0: continue
            lines.append(f"{port_type.label}:")
            self.__str__helper(getattr(self, port_type.label), lines, "  ")
        return "\n".join(lines)
    
    def __repr__(self):
        return self.__str__()
    

In [None]:
# Example: specs with dotted names are grouped into nested attribute containers, as the printed repr shows.
PortSpecCollection(
    PortSpec(PortType.INPUT,'in1'),
    PortSpec(PortType.INPUT,'a_port_subgroup.in1'),
    PortSpec(PortType.INPUT,'a_port_subgroup.in2'),
    PortSpec(PortType.OUTPUT,'out1', dtype=int),
    PortSpec(PortType.CONFIG,'conf1', dtype=str, default=''),
)

input:
  in1
  a_port_subgroup:
    in1
    in2
config:
  conf1
output:
  out1

In [None]:
#|hide
show_doc(fbdev.comp.port.BasePort)

---

### BasePort

>      BasePort ()

*Helper class that provides a standard way to create an ABC using
inheritance.*

In [None]:
#|export
class BasePort(ABC):
    """Abstract interface for a component port.

    Subclasses provide the metadata properties and the low-level `_put`/`_get` coroutines.
    `_put_value` and `_get_and_consume` are convenience wrappers built on top of them.
    """
    @property
    @abstractmethod
    def spec(self) -> PortSpec: ...
    @property
    @abstractmethod
    def name(self) -> str: ...
    @property
    @abstractmethod
    def id(self) -> PortID: ...
    @property
    @abstractmethod
    def port_type(self) -> PortType: ...
    @property
    @abstractmethod
    def dtype(self) -> type: ...
    @property
    @abstractmethod
    def is_input_port(self) -> bool: ...
    @property
    @abstractmethod
    def is_output_port(self) -> bool: ...
    @property
    @abstractmethod
    def data_validator(self) -> Callable[[Any], bool]: ...
    @property
    @abstractmethod
    def states(self) -> StateCollection: ...
        
    @abstractmethod
    async def _put(self, packet:BasePacket): ...
    
    @abstractmethod
    async def _get(self) -> BasePacket: ...
    
    async def _put_value(self, val:Any):
        """Wrap `val` in a new `Packet` and put it through `_put`."""
        await self._put(Packet(val))
        
    async def _get_and_consume(self) -> Any:
        """Get a packet through `_get` and return the result of awaiting its `consume()`."""
        packet: BasePacket = await self._get()
        return await packet.consume()

In [None]:
#|hide
show_doc(fbdev.comp.port.Port)

---

### Port

>      Port (port_spec:PortSpec)

*Helper class that provides a standard way to create an ABC using
inheritance.*

In [None]:
#|export
class Port(BasePort):
    """Concrete asyncio port that hands packets from putters to getters one at a time.

    Puts and gets rendezvous through a handshake: `_put` does not enqueue its packet until a `_get`
    has taken that put's handshake event and set it. The packet queue holds at most one packet.
    Input ports expose `get`/`get_and_consume`; output ports expose `put`/`put_value`.
    """
    def __init__(self, port_spec:PortSpec):
        self._port_spec: PortSpec = port_spec
        self._name: str = port_spec.name
        self._id: PortID = port_spec.id
        self._port_type: PortType = port_spec.port_type
        self._is_input_port: bool = port_spec.is_input_port
        self._dtype: type = port_spec.dtype
        self._data_validator: Callable[[Any], bool] = port_spec.data_validator
        self._packet: BasePacket = None  # NOTE(review): never read in this class; possibly vestigial
        
        self._states = StateCollection()
        self._states._add_state(StateHandler("is_blocked", False)) # If input port, it's blocked if the component is currently getting. If output port, it's blocked if the component is currently putting.
        self._states._add_state(StateHandler("put_awaiting", False))
        self._states._add_state(StateHandler("get_awaiting", False))
        
        # At most one packet is in flight between a matched put and get.
        self._packet_queue = asyncio.Queue(maxsize=1)
        self._gets_are_waiting_cond = asyncio.Condition()  # NOTE(review): never used in this class
        self._num_waiting_gets = 0
        self._num_waiting_puts = 0
        
        # Rendezvous queue: each pending put enqueues an Event that a get sets to release it.
        self._handshakes = asyncio.Queue()
        
        # Only expose the public operations that match the port's direction.
        if self.is_input_port:
            self.get = self._get
            self.get_and_consume = self._get_and_consume
        else:
            self.put = self._put
            self.put_value = self._put_value
    
    @property
    def spec(self) -> PortSpec: return self._port_spec
    @property
    def name(self) -> str: return self._name
    @property
    def id(self) -> PortID: return self._id
    @property
    def port_type(self) -> PortType: return self._port_type
    @property
    def dtype(self) -> type: return self._dtype
    @property
    def is_input_port(self) -> bool: return self._is_input_port
    @property
    def is_output_port(self) -> bool: return not self.is_input_port
    @property
    def data_validator(self) -> Callable[[Any], bool]: return self._data_validator
    @property
    def states(self) -> StateCollection: return self._states
        
    async def __initiate_handshake(self):
        # Putter side: publish an event, then wait until a getter sets it.
        handshake_received_event = asyncio.Event()
        await self._handshakes.put(handshake_received_event)
        await handshake_received_event.wait()
        
    async def __request_handshake(self):
        # Getter side: wait for a putter's event and release that putter.
        handshake_received_event = await self._handshakes.get()
        handshake_received_event.set()
        
    async def _put(self, packet:BasePacket):
        """Hand `packet` to a getter, waiting until one requests a handshake.

        Raises:
            ValueError: if `packet` is not a `BasePacket`.
            RuntimeError: if `packet` has already been consumed.
        """
        if not isinstance(packet, BasePacket): raise ValueError(f"`packet` is not of type `{BasePacket.__name__}`.")
        if packet.is_consumed: raise RuntimeError(f"Tried to put already-consumed packet: '{packet.uuid}'.")
        if not self._is_input_port: self.states._is_blocked.set(True)
        self._num_waiting_puts += 1
        self.states._put_awaiting.set(True)
        try:
            await self.__initiate_handshake()
            await self._packet_queue.put(packet)
        finally:
            # Also runs on cancellation, so the counter and states cannot get stuck at "waiting".
            self._num_waiting_puts -= 1
            if self._num_waiting_puts == 0:
                self.states._put_awaiting.set(False)
                if not self._is_input_port: self.states._is_blocked.set(False)
    
    async def _get(self) -> BasePacket:
        """Wait for a putter, then return the packet it delivers.

        Raises:
            RuntimeError: if the received packet has already been consumed.
        """
        if self._is_input_port: self.states._is_blocked.set(True)
        self.states._get_awaiting.set(True)
        self._num_waiting_gets += 1
        try:
            await self.__request_handshake()
            packet = await self._packet_queue.get()
        finally:
            # Also runs on cancellation, so the counter and states cannot get stuck at "waiting".
            self._num_waiting_gets -= 1
            if self._num_waiting_gets == 0:
                self.states._get_awaiting.set(False)
                if self._is_input_port: self.states._is_blocked.set(False)
        if packet.is_consumed: raise RuntimeError(f"Got already-consumed packet: '{packet.uuid}'.")
        return packet

In [None]:
port_spec = PortSpec(PortType.OUTPUT, 'out')
port = Port(port_spec)

n = 1000

async def packet_putter():
    for i in range(n):
        await port._put(Packet.get_empty())
async def packet_getter():
    for i in range(n):
        await port._get()
    
await asyncio.gather(packet_putter(), packet_getter());

In [None]:
port_spec = PortSpec(PortType.INPUT, 'in1')
port = Port(port_spec)

tasks = []

for i in range(10):
    packet = Packet(f'datum #{i}')
    async def print_data():
        packet = await port.get()
        data = await packet.consume()
        print(data)
    tasks.append(asyncio.create_task(print_data()))
    
for i in range(10):
    packet = Packet(f'datum #{i}') 
    tasks.append(asyncio.create_task(port._put(packet)))
    
await asyncio.gather(*tasks);

datum #0
datum #1
datum #2
datum #3
datum #4
datum #5
datum #6
datum #7
datum #8
datum #9


In [None]:
port_spec = PortSpec(PortType.INPUT, 'in1')
port = Port(port_spec)

async def put_packet():
    await port._put(Packet(f'data'))
    
asyncio.create_task(put_packet())
asyncio.create_task(put_packet())
await asyncio.sleep(0)
assert port.states.put_awaiting.get()
await port._get()
assert port.states.put_awaiting.get()
await port._get()
assert not port.states.put_awaiting.get()

In [None]:
port_spec = PortSpec(PortType.INPUT, 'in1')
port = Port(port_spec)

async def get_packet():
    await port._get()
    
asyncio.create_task(get_packet())
asyncio.create_task(get_packet())
await asyncio.sleep(0)
assert port.states.get_awaiting.get()
await port._put(Packet(f'data'))
assert port.states.get_awaiting.get()
await port._put(Packet(f'data'))
await asyncio.sleep(0)
assert not port.states.get_awaiting.get()

In [None]:
#|hide
show_doc(fbdev.comp.port.PortCollection)

---

### PortCollection

>      PortCollection (port_spec_collection:PortSpecCollection)

*Initialize self.  See help(type(self)) for accurate signature.*

In [None]:
#|export
class PortCollection:
    """Runtime counterpart of a `PortSpecCollection`: one `Port` per spec.

    Ports are stored by `Port.id` (a ``(PortType, name)`` pair). They are also exposed as
    attributes grouped by port-type label, mirroring the spec collection's layout
    (e.g. ``ports.output.output_group.out1``).
    """
    def __init__(self, port_spec_collection:PortSpecCollection):
        self._port_spec_collection: PortSpecCollection = port_spec_collection
        self._ports: Dict[PortID, Port] = {}
        for port_type in PortType:
            setattr(self, port_type.label, AttrContainer({}, obj_name=f"{PortCollection.__name__}.{port_type.label}"))
        for port_spec in port_spec_collection.iter_ports():
            self.__add_port(Port(port_spec))
    
    def __add_port(self, port:Port):
        # Register `port` by id and at its (possibly nested) attribute path.
        if not is_valid_name(port.name): raise ValueError(f"Invalid port name '{port.name}'.")
        if port.id in self._ports: raise ValueError(f"Port name '{port.name}' already exists in {self.__class__.__name__}.")
        self._ports[port.id] = port
        
        name_parts = port.name.split('.')
        name_stem = name_parts.pop()
        attr_container = getattr(self, port.port_type.label)
        attr_container_addr = f"{PortCollection.__name__}.{port.port_type.label}"
        for name_part in name_parts:
            attr_container_addr += f".{name_part}"
            if not name_part in attr_container:
                attr_container._set(name_part, AttrContainer({}, obj_name=attr_container_addr))
            attr_container = attr_container[name_part]
        attr_container._set(name_stem, port)
    
    def __getitem__(self, key:PortID) -> Port:
        if key in self._ports: return self._ports[key]
        else: raise KeyError(f"'{key}' does not exist in {self.__class__.__name__}.")
    
    def __iter__(self): return self._ports.__iter__()
    def __len__(self): return self._ports.__len__()
    def __contains__(self, key) -> bool:
        """Return whether a port with id `key` exists.

        Raises:
            TypeError: if `key` is a tuple that is not a ``(PortType, str)`` pair.
        """
        if isinstance(key, tuple):
            # Check the length before indexing so short tuples get the TypeError, not IndexError.
            if len(key) != 2 or not isinstance(key[0], PortType) or not isinstance(key[1], str):
                raise TypeError(f"Key must be a tuple of (PortType, str). Got '{key}'.")
        # Ids are stored as (PortType, name) tuples, so the key is looked up as-is.
        return key in self._ports
    def as_dict(self) -> MappingProxyType: return MappingProxyType(self._ports)
    
    def iter_ports(self) -> Iterator[Port]: return self._ports.values().__iter__()

    def __str__(self): return self._port_spec_collection.__str__()
    
    def __repr__(self): return self._port_spec_collection.__repr__()

In [None]:
# Example: a PortCollection creates one Port per spec; its repr delegates to the underlying PortSpecCollection.
port_spec_collection = PortSpecCollection(
    PortSpec(PortType.INPUT,'in1'),
    PortSpec(PortType.OUTPUT,'out1', dtype=int),
    PortSpec(PortType.OUTPUT,'output_group.out1', dtype=int),
    PortSpec(PortType.OUTPUT,'output_group.out2', dtype=int),
    PortSpec(PortType.CONFIG,'conf1', dtype=str, default=''),
)

PortCollection(port_spec_collection)

input:
  in1
config:
  conf1
output:
  out1
  output_group:
    out1
    out2