# Packet Value Storage

The `PacketValueStore` manages packet values separately from `netrun-sim`'s packet tracking.
It supports direct values, lazy (sync or async) value functions, optional retention of consumed values, and optional file-based persistence.

In [None]:
#|default_exp storage

In [None]:
#|hide
from nblite import nbl_export, show_doc; nbl_export();

In [None]:
#|export
from typing import Any, Callable, Optional, Union
from dataclasses import dataclass
from collections import OrderedDict
from pathlib import Path
import asyncio
import inspect
import pickle

from netrun.errors import ValueFunctionFailed


@dataclass
class StoredValue:
    """
    A stored packet value, either direct or via a value function.

    Attributes:
        value: The direct value; returned when ``is_value_func`` is False or
            when no ``value_func`` is set.
        value_func: Zero-argument callable producing the value lazily. It may
            return an awaitable, which only ``async_get_value`` resolves.
        is_value_func: Whether ``value_func`` should be used instead of ``value``.
    """
    value: Any = None
    value_func: Optional[Callable[[], Any]] = None
    is_value_func: bool = False

    def get_value(self) -> Any:
        """
        Get the value, calling the value function if necessary.

        The value function's result is returned as-is. For an async value
        function this is an un-awaited coroutine; use ``async_get_value``.
        """
        if self.is_value_func and self.value_func is not None:
            return self.value_func()
        return self.value

    async def async_get_value(self) -> Any:
        """
        Get the value, awaiting the value function's result if it is awaitable.

        Any awaitable result (coroutine, ``asyncio.Future``/``Task``, or an
        object implementing ``__await__``) is awaited. Other results, and
        direct values, are returned unchanged.
        """
        if not (self.is_value_func and self.value_func is not None):
            return self.value
        result = self.value_func()
        # isawaitable (not just iscoroutine) so that Futures/Tasks returned by
        # a value function are resolved instead of leaking to the caller.
        if inspect.isawaitable(result):
            return await result
        return result


class PacketValueStore:
    """
    Manages packet values for the netrun runtime.

    Handles:
    - Direct value storage/retrieval by packet ID
    - Lazy value functions (called when the value is read)
    - Consumed packet storage with configurable limits
    - Optional file-based persistence (pickle files named ``<packet_id>.pkl``)
    """

    # Private sentinel distinguishing "no persisted file" from a persisted
    # value that is legitimately ``None``.
    _MISSING = object()

    def __init__(
        self,
        consumed_storage: bool = False,
        consumed_storage_limit: Optional[int] = None,
        storage_path: Optional[Union[str, Path]] = None,
    ):
        """
        Initialize the packet value store.

        Args:
            consumed_storage: Whether to keep values after consumption.
            consumed_storage_limit: Max consumed values to keep (None = unlimited).
                A limit <= 0 keeps nothing.
            storage_path: Optional directory for file-based storage; it is
                created (with parents) if it does not exist.
        """
        self._values: dict[str, StoredValue] = {}
        self._consumed_values: OrderedDict[str, Any] = OrderedDict()
        self._consumed_storage = consumed_storage
        self._consumed_storage_limit = consumed_storage_limit
        self._storage_path = Path(storage_path) if storage_path else None

        if self._storage_path:
            self._storage_path.mkdir(parents=True, exist_ok=True)

    def store_value(self, packet_id: str, value: Any) -> None:
        """Store a direct value for a packet (also persisted when file storage is enabled)."""
        self._values[packet_id] = StoredValue(value=value, is_value_func=False)

        if self._storage_path:
            self._persist_to_file(packet_id, value)

    def store_value_func(self, packet_id: str, func: Callable[[], Any]) -> None:
        """
        Store a value function for a packet (called lazily when the value is read).

        Value functions are never persisted to file storage.
        """
        self._values[packet_id] = StoredValue(value_func=func, is_value_func=True)

    def get_value(self, packet_id: str) -> Any:
        """
        Get a packet's value without removing it from the store.

        Lookup order: active storage (calling the value function, if any),
        then consumed storage, then file storage.

        Raises:
            KeyError: If the packet is not found anywhere.
            ValueFunctionFailed: If the packet's value function raises.
        """
        if packet_id not in self._values:
            return self._get_inactive_value(packet_id)

        stored = self._values[packet_id]
        try:
            return stored.get_value()
        except Exception as e:
            raise ValueFunctionFailed(packet_id, e) from e

    async def async_get_value(self, packet_id: str) -> Any:
        """
        Async version of get_value, supporting async value functions.

        Raises:
            KeyError: If the packet is not found anywhere.
            ValueFunctionFailed: If the packet's value function raises.
        """
        if packet_id not in self._values:
            return self._get_inactive_value(packet_id)

        stored = self._values[packet_id]
        try:
            return await stored.async_get_value()
        except Exception as e:
            raise ValueFunctionFailed(packet_id, e) from e

    def consume(self, packet_id: str) -> Any:
        """
        Consume a packet's value, removing it from active storage.

        If consumed_storage is enabled, the value is kept in consumed storage.
        Any persisted file is left in place.
        """
        value = self.get_value(packet_id)
        self._remove_from_active(packet_id)

        if self._consumed_storage:
            self._add_to_consumed(packet_id, value)

        return value

    async def async_consume(self, packet_id: str) -> Any:
        """Async version of consume."""
        value = await self.async_get_value(packet_id)
        self._remove_from_active(packet_id)

        if self._consumed_storage:
            self._add_to_consumed(packet_id, value)

        return value

    def unconsume(self, packet_id: str, value: Any) -> None:
        """
        Restore a consumed packet's value (for retry scenarios).

        This moves the value back to active storage as a direct value.
        """
        self._values[packet_id] = StoredValue(value=value, is_value_func=False)
        self._consumed_values.pop(packet_id, None)

    def has_value(self, packet_id: str) -> bool:
        """Check if a packet has a value in active storage."""
        return packet_id in self._values

    def remove(self, packet_id: str) -> None:
        """Remove a packet's value from the store entirely, including any persisted file."""
        self._remove_from_active(packet_id)
        self._consumed_values.pop(packet_id, None)
        # Without this, get_value would keep finding the value via the file.
        file_path = self._file_path(packet_id)
        if file_path is not None:
            file_path.unlink(missing_ok=True)

    def get_consumed_value(self, packet_id: str) -> Optional[Any]:
        """Get a value from consumed storage (doesn't remove it); None if absent."""
        return self._consumed_values.get(packet_id)

    def _get_inactive_value(self, packet_id: str) -> Any:
        """Look up a packet absent from active storage: consumed storage, then file."""
        if packet_id in self._consumed_values:
            return self._consumed_values[packet_id]
        value = self._load_persisted(packet_id)
        if value is not self._MISSING:
            return value
        raise KeyError(f"Packet {packet_id} not found in value store")

    def _remove_from_active(self, packet_id: str) -> None:
        """Remove from active storage (no-op if absent)."""
        self._values.pop(packet_id, None)

    def _add_to_consumed(self, packet_id: str, value: Any) -> None:
        """Add to consumed storage, evicting the oldest entries beyond the limit."""
        self._consumed_values[packet_id] = value
        # A re-consumed ID counts as newest, so eviction stays oldest-first.
        self._consumed_values.move_to_end(packet_id)

        limit = self._consumed_storage_limit
        if limit is not None:
            # The emptiness check stops a negative limit from calling
            # popitem() on an empty dict.
            while self._consumed_values and len(self._consumed_values) > limit:
                self._consumed_values.popitem(last=False)

    def _file_path(self, packet_id: str) -> Optional[Path]:
        """Return the pickle file path for a packet, or None if file storage is off."""
        if self._storage_path is None:
            return None
        # NOTE(review): packet_id is used verbatim as a file name; an ID
        # containing path separators or ".." would escape storage_path.
        return self._storage_path / f"{packet_id}.pkl"

    def _persist_to_file(self, packet_id: str, value: Any) -> None:
        """Persist a value to file storage."""
        file_path = self._file_path(packet_id)
        if file_path is not None:
            with open(file_path, 'wb') as f:
                pickle.dump(value, f)

    def _load_persisted(self, packet_id: str) -> Any:
        """Load a persisted value, or return ``_MISSING`` if there is none."""
        file_path = self._file_path(packet_id)
        if file_path is None or not file_path.exists():
            return self._MISSING
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # storage_path at trusted, locally written data.
        with open(file_path, 'rb') as f:
            return pickle.load(f)

    def _load_from_file(self, packet_id: str) -> Optional[Any]:
        """Load a value from file storage (None if there is no file)."""
        value = self._load_persisted(packet_id)
        return None if value is self._MISSING else value