# Node Execution Context

The `NodeExecutionContext` is passed to node execution functions. It provides
methods for packet operations, access to retry information, and epoch control.

In [None]:
#|default_exp context

In [None]:
#|hide
from nblite import nbl_export, show_doc; nbl_export();

In [None]:
#|export
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
from datetime import datetime

from netrun_sim import Packet, NetAction

from netrun.errors import EpochCancelled
from netrun.deferred import DeferredPacket, DeferredActionQueue

if TYPE_CHECKING:
    from netrun.net import Net


class NodeExecutionContext:
    """
    Context passed to node execution functions.

    Provides access to packet operations, retry information, and epoch control.
    All packet operations respect the defer_net_actions setting: when it is
    enabled, operations are recorded in a DeferredActionQueue instead of being
    applied to the underlying NetSim immediately.
    """

    def __init__(
        self,
        net: "Net",
        epoch_id: str,
        node_name: str,
        defer_net_actions: bool = False,
        retry_count: int = 0,
        retry_timestamps: Optional[List[datetime]] = None,
        retry_exceptions: Optional[List[Exception]] = None,
    ):
        """
        Initialize the execution context.

        Args:
            net: The Net instance
            epoch_id: ID of the current epoch
            node_name: Name of the node being executed
            defer_net_actions: Whether to buffer actions until successful completion
            retry_count: Current retry attempt (0 = first attempt)
            retry_timestamps: Timestamps of previous retry attempts
            retry_exceptions: Exceptions from previous retries
        """
        self._net = net
        self._epoch_id = epoch_id
        self._node_name = node_name
        self._defer_net_actions = defer_net_actions
        self._retry_count = retry_count

        # Snapshot the retry history so that later mutation of the caller's
        # lists cannot change what this context reports (the properties below
        # already hand out copies for the same reason).
        self._retry_timestamps: List[datetime] = (
            list(retry_timestamps) if retry_timestamps is not None else []
        )
        self._retry_exceptions: List[Exception] = (
            list(retry_exceptions) if retry_exceptions is not None else []
        )

        # Deferred action queue (only created if defer_net_actions=True)
        self._deferred_queue: Optional[DeferredActionQueue] = (
            DeferredActionQueue() if defer_net_actions else None
        )

        # Values consumed during this execution, keyed by packet ID; exposed via
        # _get_consumed_values() for building the failure context.
        self._consumed_values: dict[str, Any] = {}

    # -------------------------------------------------------------------------
    # Properties
    # -------------------------------------------------------------------------

    @property
    def epoch_id(self) -> str:
        """The ID of the current epoch."""
        return self._epoch_id

    @property
    def node_name(self) -> str:
        """The name of the node being executed."""
        return self._node_name

    @property
    def retry_count(self) -> int:
        """Current retry attempt (0 = first attempt)."""
        return self._retry_count

    @property
    def retry_timestamps(self) -> List[datetime]:
        """Timestamps of previous retry attempts (a copy)."""
        return self._retry_timestamps.copy()

    @property
    def retry_exceptions(self) -> List[Exception]:
        """Exceptions from previous retries (a copy)."""
        return self._retry_exceptions.copy()

    # -------------------------------------------------------------------------
    # Packet Operations (Sync)
    # -------------------------------------------------------------------------

    def create_packet(self, value: Any) -> Union[Packet, DeferredPacket]:
        """
        Create a new packet with a direct value.

        Args:
            value: The value to associate with the new packet.

        Returns:
            A DeferredPacket if defer_net_actions=True; otherwise the Packet
            created immediately in NetSim (with its value stored).
        """
        queue = self._active_deferred_queue()
        if queue is not None:
            return queue.create_packet(value)

        packet_id = self._create_packet_in_sim()
        self._net._value_store.store_value(packet_id, value)
        return self._net._sim.get_packet(packet_id)

    def create_packet_from_value_func(self, func: Callable[[], Any]) -> Union[Packet, DeferredPacket]:
        """
        Create a new packet with a lazy value function.

        The function is called when the packet is consumed.

        Args:
            func: Zero-argument callable producing the packet's value.

        Returns:
            A DeferredPacket if defer_net_actions=True; otherwise the Packet
            created immediately in NetSim (with its value function stored).
        """
        queue = self._active_deferred_queue()
        if queue is not None:
            return queue.create_packet_from_func(func)

        packet_id = self._create_packet_in_sim()
        self._net._value_store.store_value_func(packet_id, func)
        return self._net._sim.get_packet(packet_id)

    def consume_packet(self, packet: Union[Packet, DeferredPacket]) -> Any:
        """
        Consume a packet and return its value.

        Removes the packet from the network and returns the stored value.
        If the packet has a value function, it is called.

        Args:
            packet: The packet to consume.

        Returns:
            The value read from the value store for this packet.

        Raises:
            ValueError: If ``packet`` is a DeferredPacket that is not resolved.
        """
        packet_id = self._resolve_packet_id(packet, "consume")
        # Read the value first (this calls value functions if needed), then
        # record it and apply or defer the NetSim consume action.
        value = self._net._value_store.consume(packet_id)
        self._finish_consume(packet, packet_id, value)
        return value

    def load_output_port(self, port_name: str, packet: Union[Packet, DeferredPacket]) -> None:
        """
        Load a packet into an output port.

        The packet must have been created in this epoch.

        Args:
            port_name: Name of the output port to load into.
            packet: The packet to load.

        Raises:
            ValueError: In immediate mode, if ``packet`` is an unresolved
                DeferredPacket. (In deferred mode the packet is handed to the
                deferred queue as-is.)
        """
        queue = self._active_deferred_queue()
        if queue is not None:
            queue.load_output_port(port_name, packet)
            return

        packet_id = self._resolve_packet_id(packet, "load")
        action = NetAction.load_packet_into_output_port(packet_id, port_name)
        self._net._sim.do_action(action)

    def send_output_salvo(self, salvo_condition_name: str) -> None:
        """
        Send packets from output ports via a salvo condition.

        The salvo condition must be satisfied for sending to succeed.

        Args:
            salvo_condition_name: Name of the output salvo condition to fire.
        """
        queue = self._active_deferred_queue()
        if queue is not None:
            queue.send_output_salvo(salvo_condition_name)
            return

        action = NetAction.send_output_salvo(self._epoch_id, salvo_condition_name)
        self._net._sim.do_action(action)

    # -------------------------------------------------------------------------
    # Packet Operations (Async)
    # -------------------------------------------------------------------------

    async def async_consume_packet(self, packet: Union[Packet, DeferredPacket]) -> Any:
        """
        Async version of consume_packet.

        Supports async value functions.

        Args:
            packet: The packet to consume.

        Returns:
            The value read from the value store for this packet.

        Raises:
            ValueError: If ``packet`` is a DeferredPacket that is not resolved.
        """
        packet_id = self._resolve_packet_id(packet, "consume")
        # Get the value (async to support async value functions)
        value = await self._net._value_store.async_consume(packet_id)
        self._finish_consume(packet, packet_id, value)
        return value

    # -------------------------------------------------------------------------
    # Epoch Control
    # -------------------------------------------------------------------------

    def cancel_epoch(self) -> None:
        """
        Cancel the current epoch.

        Always raises EpochCancelled (carrying the node name and epoch ID),
        which should not be caught by the node.
        """
        raise EpochCancelled(self._node_name, self._epoch_id)

    # -------------------------------------------------------------------------
    # Internal Methods
    # -------------------------------------------------------------------------

    def _get_deferred_queue(self) -> Optional[DeferredActionQueue]:
        """Get the deferred action queue (for internal use)."""
        return self._deferred_queue

    def _get_consumed_values(self) -> dict[str, Any]:
        """Get a copy of the consumed values (for failure context)."""
        return self._consumed_values.copy()

    def _active_deferred_queue(self) -> Optional[DeferredActionQueue]:
        """Return the deferred queue when actions are being deferred, else None."""
        if self._defer_net_actions and self._deferred_queue is not None:
            return self._deferred_queue
        return None

    @staticmethod
    def _resolve_packet_id(packet: Union[Packet, DeferredPacket], verb: str) -> Any:
        """
        Return ``packet.id``, rejecting unresolved deferred packets.

        ``verb`` is interpolated into the error message
        (e.g. "consume" -> "Cannot consume an unresolved deferred packet").
        """
        if isinstance(packet, DeferredPacket) and not packet.is_resolved:
            raise ValueError(f"Cannot {verb} an unresolved deferred packet")
        return packet.id

    def _create_packet_in_sim(self) -> Any:
        """Create a packet for the current epoch in NetSim and return its ID."""
        action = NetAction.create_packet(self._epoch_id)
        response_data, _ = self._net._sim.do_action(action)
        return response_data.packet_id

    def _finish_consume(
        self, packet: Union[Packet, DeferredPacket], packet_id: Any, value: Any
    ) -> None:
        """Record a consumed value, then defer or apply the NetSim consume action."""
        self._consumed_values[packet_id] = value
        queue = self._active_deferred_queue()
        if queue is not None:
            queue.consume_packet(packet, value)
        else:
            self._net._sim.do_action(NetAction.consume_packet(packet_id))

## Node Failure Context

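The `NodeFailureContext` is passed to the `exec_failed_node_func` callback
after a failed execution attempt. It exposes the retry information, the input
salvo, the values consumed before the failure, and the exception that was
raised. It provides no packet operations.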

In [None]:
#|export
class NodeFailureContext:
    """
    Context passed to node failure handlers after execution failure.

    Provides read-only access to retry information, input packets, consumed
    values, and the exception that caused the failure. Does not provide packet
    operations (execution has already failed).

    The containers passed to the constructor are snapshotted, and every
    collection property returns a fresh copy, so neither the creator nor a
    handler can alter the recorded lists/dicts after the fact. The contained
    objects themselves (packets, exceptions, values) are shared, not copied.
    """

    def __init__(
        self,
        epoch_id: str,
        node_name: str,
        retry_count: int,
        retry_timestamps: List[datetime],
        retry_exceptions: List[Exception],
        input_salvo: dict[str, list[Packet]],
        packet_values: dict[str, Any],
        exception: Exception,
    ):
        """
        Initialize the failure context.

        Args:
            epoch_id: ID of the failed epoch
            node_name: Name of the node that failed
            retry_count: Current retry attempt (0 = first attempt)
            retry_timestamps: Timestamps of all retry attempts including current
            retry_exceptions: Exceptions from all retries including current
            input_salvo: The input packets that triggered this epoch, by port
            packet_values: Values that were consumed during execution
            exception: The exception that caused the failure
        """
        self._epoch_id = epoch_id
        self._node_name = node_name
        self._retry_count = retry_count
        # Snapshot every container so later mutation by the caller cannot
        # change the failure record seen by the handler.
        self._retry_timestamps = list(retry_timestamps)
        self._retry_exceptions = list(retry_exceptions)
        self._input_salvo = {
            port: list(packets) for port, packets in input_salvo.items()
        }
        self._packet_values = dict(packet_values)
        self._exception = exception

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(epoch_id={self._epoch_id!r}, "
            f"node_name={self._node_name!r}, retry_count={self._retry_count!r}, "
            f"exception={self._exception!r})"
        )

    @property
    def epoch_id(self) -> str:
        """The ID of the failed epoch."""
        return self._epoch_id

    @property
    def node_name(self) -> str:
        """The name of the node that failed."""
        return self._node_name

    @property
    def retry_count(self) -> int:
        """Current retry attempt (0 = first attempt)."""
        return self._retry_count

    @property
    def retry_timestamps(self) -> List[datetime]:
        """Timestamps of all retry attempts including current (a copy)."""
        return self._retry_timestamps.copy()

    @property
    def retry_exceptions(self) -> List[Exception]:
        """Exceptions from all retries including current (a copy)."""
        return self._retry_exceptions.copy()

    @property
    def input_salvo(self) -> dict[str, list[Packet]]:
        """The input packets that triggered this epoch, by port.

        Both the dict and each per-port list are copies; a plain ``dict.copy()``
        would share the inner lists with the stored state.
        """
        return {port: list(packets) for port, packets in self._input_salvo.items()}

    @property
    def packet_values(self) -> dict[str, Any]:
        """Values that were consumed during execution (a copy)."""
        return self._packet_values.copy()

    @property
    def exception(self) -> Exception:
        """The exception that caused the failure."""
        return self._exception