# Node Factories

Node factories are Python modules that generate node specifications and execution functions.
This enables reusable node templates that can be configured with different parameters.

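A node factory module must contain:
- `get_node_spec(**args) -> dict`: Returns the kwargs for `netrun_sim.Node`. It should be lightweight, since it may be called for UI introspection.
- `get_node_funcs(**args) -> tuple`: Returns the 4-tuple `(exec_func, start_func, stop_func, failed_func)`; any entry may be `None`.

Alternatively, a factory can be a single function that returns a `NodeFactoryResult`, a `(spec, funcs)` pair (where `funcs` is the same 4-tuple), or a flat 5-tuple `(spec, exec_func, start_func, stop_func, failed_func)`.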

In [None]:
#|default_exp factories

In [None]:
#|export
from dataclasses import dataclass
from typing import Any, Callable, Optional, Dict, Tuple, Protocol, runtime_checkable

from netrun.dsl import resolve_import_path

In [None]:
#|export
@runtime_checkable
class NodeFactory(Protocol):
    """
    Structural interface expected of a node factory module.

    Any object (normally an imported module) that exposes both
    ``get_node_spec`` and ``get_node_funcs`` conforms. Because the protocol
    is ``runtime_checkable``, ``isinstance(obj, NodeFactory)`` only checks
    that the two attributes are present, not their signatures.
    """

    def get_node_spec(self, **args) -> Dict[str, Any]:
        """
        Build the keyword arguments for the ``netrun_sim.Node`` constructor.

        Meant to be cheap so that UIs can call it for introspection.
        """

    def get_node_funcs(self, **args) -> Tuple[
        Optional[Callable],  # exec_node_func
        Optional[Callable],  # start_node_func
        Optional[Callable],  # stop_node_func
        Optional[Callable],  # exec_failed_node_func
    ]:
        """
        Build the node's callbacks.

        Returns the 4-tuple ``(exec_node_func, start_node_func,
        stop_node_func, exec_failed_node_func)``; any entry may be None.
        """

In [None]:
#|export
@dataclass
class NodeFactoryResult:
    """
    Everything produced by invoking a node factory.

    Bundles the ``netrun_sim.Node`` constructor kwargs with the optional
    lifecycle callbacks. It also records which factory produced them, and
    with which arguments, so the result can be serialized.

    Attributes:
        node_spec: Keyword arguments for the ``netrun_sim.Node`` constructor
            (in_ports, out_ports, in_salvo_conditions, out_salvo_conditions).
        exec_node_func: Main execution callback for the node.
        start_node_func: Callback invoked when the net starts.
        stop_node_func: Callback invoked when the net stops.
        exec_failed_node_func: Callback invoked after failed execution attempts.
        factory_path: Dotted import path of the originating factory.
        factory_args: Keyword arguments the factory was called with.
    """

    # Required: constructor kwargs for the Node.
    node_spec: Dict[str, Any]

    # Lifecycle callbacks; each one is optional and defaults to None.
    exec_node_func: Optional[Callable] = None
    start_node_func: Optional[Callable] = None
    stop_node_func: Optional[Callable] = None
    exec_failed_node_func: Optional[Callable] = None

    # Provenance of this result, kept for serialization.
    factory_path: Optional[str] = None
    factory_args: Optional[Dict[str, Any]] = None

In [None]:
#|export
class FactoryError(Exception):
    """Root of the exception hierarchy raised by the node-factory loader."""


class FactoryNotFoundError(FactoryError):
    """The factory module or function could not be located or imported."""


class InvalidFactoryError(FactoryError):
    """
    The factory was found but does not follow the factory contract.

    Examples: it is missing a required function, or it returns a value of
    the wrong type or shape.
    """

In [None]:
#|export
def load_factory(
    factory_path: str,
    **factory_args: Any
) -> NodeFactoryResult:
    """
    Load a node factory module and call its functions.

    Args:
        factory_path: Dotted import path to the factory module.
            Can be either:
            - A module path (e.g., "my_module.my_factory") - module must have
              get_node_spec and get_node_funcs functions
            - A function path (e.g., "my_module.create_node") - function must
              return a NodeFactoryResult or a tuple of (spec, funcs)
        **factory_args: Arguments to pass to the factory functions

    Returns:
        NodeFactoryResult with node spec and execution functions

    Raises:
        FactoryNotFoundError: If the factory cannot be imported or resolved
        InvalidFactoryError: If the factory doesn't have required functions
            or returns values of the wrong shape
    """
    try:
        factory = resolve_import_path(factory_path)
    except (ImportError, AttributeError) as e:
        # Chain the original error so the root cause stays in the traceback.
        # AttributeError is also caught in case resolve_import_path reports a
        # missing attribute (e.g. a misspelled function in an existing module)
        # that way. NOTE(review): confirm resolve_import_path's failure modes.
        raise FactoryNotFoundError(
            f"Cannot import factory '{factory_path}': {e}"
        ) from e

    # A callable without get_node_spec is treated as a factory *function*.
    # Anything else (normally a module) must expose get_node_spec/get_node_funcs.
    if callable(factory) and not hasattr(factory, 'get_node_spec'):
        return _call_factory_function(factory, factory_path, factory_args)

    return _call_factory_module(factory, factory_path, factory_args)


def _call_factory_function(
    factory_func: Callable,
    factory_path: str,
    factory_args: Dict[str, Any]
) -> NodeFactoryResult:
    """
    Call a factory *function* and normalize its return value.

    Accepted return shapes:
    - a NodeFactoryResult (a copy is returned, with factory_path and
      factory_args filled in);
    - a 2-tuple ``(spec, funcs)``, where funcs is a 4-tuple
      ``(exec, start, stop, failed)``;
    - a flat 5-tuple ``(spec, exec, start, stop, failed)``.

    Raises:
        InvalidFactoryError: if the return value has any other shape, or if
            the spec in a tuple return is not a dict.
    """
    from dataclasses import replace

    result = factory_func(**factory_args)

    if isinstance(result, NodeFactoryResult):
        # Copy instead of mutating: the factory may hand out a shared or cached
        # instance, which must not be stamped with this particular call's args.
        return replace(result, factory_path=factory_path, factory_args=factory_args)

    if not isinstance(result, tuple):
        raise InvalidFactoryError(
            f"Factory function '{factory_path}' must return a NodeFactoryResult "
            f"or tuple, got {type(result)}"
        )

    if len(result) == 2:
        spec, funcs = result
        if not (isinstance(funcs, tuple) and len(funcs) == 4):
            raise InvalidFactoryError(
                f"Factory function '{factory_path}' returned invalid funcs tuple. "
                f"Expected 4 elements, got {len(funcs) if isinstance(funcs, tuple) else type(funcs)}"
            )
    elif len(result) == 5:
        # Flat form: (spec, exec, start, stop, failed)
        spec, funcs = result[0], result[1:]
    else:
        raise InvalidFactoryError(
            f"Factory function '{factory_path}' returned invalid tuple. "
            f"Expected 2 or 5 elements, got {len(result)}"
        )

    # Same requirement the module path enforces on get_node_spec.
    if not isinstance(spec, dict):
        raise InvalidFactoryError(
            f"Factory function '{factory_path}' returned a node spec that is "
            f"not a dict, got {type(spec)}"
        )

    exec_func, start_func, stop_func, failed_func = funcs
    return NodeFactoryResult(
        node_spec=spec,
        exec_node_func=exec_func,
        start_node_func=start_func,
        stop_node_func=stop_func,
        exec_failed_node_func=failed_func,
        factory_path=factory_path,
        factory_args=factory_args,
    )


def _call_factory_module(
    factory_module: Any,
    factory_path: str,
    factory_args: Dict[str, Any]
) -> NodeFactoryResult:
    """
    Call ``get_node_spec`` and ``get_node_funcs`` on a factory module.

    Both functions receive the same ``factory_args`` as keyword arguments.

    Raises:
        InvalidFactoryError: if either function is missing or not callable,
            if get_node_spec does not return a dict, or if get_node_funcs does
            not return a 4-tuple.
    """
    # Check for the required functions before calling anything.
    for func_name in ('get_node_spec', 'get_node_funcs'):
        func = getattr(factory_module, func_name, None)
        if func is None:
            raise InvalidFactoryError(
                f"Factory module '{factory_path}' missing required function '{func_name}'"
            )
        if not callable(func):
            # Otherwise calling it below would raise a bare TypeError.
            raise InvalidFactoryError(
                f"Factory module '{factory_path}' attribute '{func_name}' is not callable"
            )

    # Validate the (meant-to-be-lightweight) spec first, so a bad spec is
    # reported without paying for get_node_funcs.
    node_spec = factory_module.get_node_spec(**factory_args)
    if not isinstance(node_spec, dict):
        raise InvalidFactoryError(
            f"Factory '{factory_path}' get_node_spec must return a dict, "
            f"got {type(node_spec)}"
        )

    funcs = factory_module.get_node_funcs(**factory_args)
    if not isinstance(funcs, tuple) or len(funcs) != 4:
        raise InvalidFactoryError(
            f"Factory '{factory_path}' get_node_funcs must return a 4-tuple, "
            f"got {type(funcs)} with length {len(funcs) if isinstance(funcs, tuple) else 'N/A'}"
        )

    exec_func, start_func, stop_func, failed_func = funcs
    return NodeFactoryResult(
        node_spec=node_spec,
        exec_node_func=exec_func,
        start_node_func=start_func,
        stop_node_func=stop_func,
        exec_failed_node_func=failed_func,
        factory_path=factory_path,
        factory_args=factory_args,
    )

In [None]:
#|export
def create_node_from_factory(
    factory_path: str,
    **factory_args: Any
) -> Tuple[Any, NodeFactoryResult]:
    """
    Build a ``netrun_sim.Node`` directly from a factory.

    Convenience wrapper around ``load_factory``. It constructs the Node from
    the resulting spec and returns both objects, so the caller can still wire
    up the execution callbacks held in the factory result.

    Args:
        factory_path: Dotted import path to the factory
        **factory_args: Arguments to pass to the factory

    Returns:
        ``(node, factory_result)``

    Raises:
        FactoryError subclasses propagated from ``load_factory``.
    """
    # netrun_sim is imported here (deferred) rather than at module level.
    from netrun_sim import Node

    factory_result = load_factory(factory_path, **factory_args)
    spec = factory_result.node_spec
    return Node(**spec), factory_result

In [None]:
#|export
def is_json_serializable(value: Any) -> bool:
    """
    Return True if ``json.dumps`` can encode *value* with default settings.

    Never raises for unencodable input; returns False instead. Default
    settings mean, for example, that tuples are accepted (encoded as arrays)
    and NaN/Infinity floats are accepted (as non-standard tokens).
    """
    import json
    try:
        json.dumps(value)
    except (TypeError, ValueError, RecursionError):
        # TypeError: unsupported type. ValueError: e.g. a circular reference.
        # RecursionError: nesting too deep for the encoder. Without this last
        # case the error would escape from what is meant to be a yes/no check.
        return False
    return True


def validate_factory_args(factory_args: Dict[str, Any]) -> None:
    """
    Ensure every factory argument can be persisted.

    Each value must be JSON-serializable. Callables are rejected with a
    dedicated message; pass them as an import-path string instead
    (e.g. ``'my_module.my_func'``).

    Raises:
        ValueError: for the first argument (in dict order) that is callable
            or not JSON-serializable.
    """
    for arg_name, arg_value in factory_args.items():
        # Tested before serializability so callables get the more helpful hint.
        if callable(arg_value):
            message = (
                f"Factory argument '{arg_name}' is a callable. "
                f"Use an import path string instead (e.g., 'my_module.my_func')"
            )
            raise ValueError(message)

        serializable = is_json_serializable(arg_value)
        if not serializable:
            raise ValueError(
                f"Factory argument '{arg_name}' is not JSON-serializable: {type(arg_value)}"
            )