# API CHANGE - store param_groups in var
- Modular creates var with its param_groups
- passes var to each module sequentially
- if a module has custom param_groups, temporarily set them to var
	- this also means that all child modules will have the same param_groups, which seems like the expected behavior
- defaults are handled naturally
- projections - they set projected param_groups to var with either fake or real projected params. Params need to be stored in the same Tensor objects for states to work (as projection already does). A projected module steps with the projected var, so no initialization logic is needed. Then unproject, empty the projected params and return var
- Wrap - can initialize the wrapped optimizer with var.param_groups on the first step. If there are no custom parameters, make sure per-parameter settings are cleared (or just lr? Currently I have just lr, and maybe that is better)
- Modules are no longer tied to param_groups, although they still store a per-parameter state.
- Might also add state to var. But having param_groups in var and state in both self and var can be confusing, so it might need a different name, maybe persistent_state? The only purpose of this would be to support variables, and I am not sure that is needed. If I ever need variables, I will add this.

- option 3 - same as option 2 but set param_groups object to the module and then delete it after stepping.
	- state_vals don't need params, more consistent with group_vals
	- param_groups is now on self, more consistent with state

In [11]:
from typing import Any
from collections import UserDict
class ParamGroup(UserDict[str, Any]):
    """Param-group mapping whose missing keys fall back to shared defaults.

    Keys set on the group itself live in ``self.data`` and take precedence;
    any other key is looked up in ``self.defaults``. Every read path --
    ``[]``, ``in``, ``get``, iteration, ``len``, ``keys``/``values``/``items``
    and ``repr`` -- sees the same merged view (``defaults | data``), so the
    group behaves like a plain dict of the effective settings. Writes and
    deletes only touch ``self.data``; ``defaults`` is never modified.

    Args:
        group: per-group settings; copied into ``self.data``.
        defaults: fallback settings; stored by reference (not copied), so later
            changes to that dict are visible through the group.
    """
    __slots__ = ('defaults', )

    def __init__(self, group: dict[str, Any], defaults: dict[str, Any]):
        # Set before UserDict.__init__ so the merged-view methods below are
        # usable from the very first moment the instance exists.
        self.defaults = defaults
        super().__init__(group)

    def _merged(self) -> dict[str, Any]:
        """Effective settings: defaults first, overridden by the group's own keys."""
        return self.defaults | self.data

    def __getitem__(self, k):
        if k in self.data: return self.data[k]
        return self.defaults[k]

    # UserDict implements the following against ``self.data`` only. Without
    # these overrides ``"b" in group`` was False and ``len``/iteration skipped
    # default keys; and because UserDict's ``get`` (in recent Pythons) checks
    # ``in`` first, ``group.get("b")`` could return None even though
    # ``group["b"]`` succeeded.
    def __contains__(self, k):
        return k in self.data or k in self.defaults

    def get(self, k, default=None):
        return self[k] if k in self else default

    def __iter__(self): return iter(self._merged())
    def __len__(self): return len(self._merged())

    def keys(self): return self._merged().keys()
    def values(self): return self._merged().values()
    def items(self): return self._merged().items()

    def copy(self):
        """Shallow copy that keeps the group/defaults split (defaults are shared).

        Overridden because UserDict.copy re-populates the copy by iterating
        ``self``, which (with the merged ``__iter__``) would copy default
        values into the new group's own data.
        """
        return type(self)(self.data, self.defaults)

    def __repr__(self):
        return repr(self._merged())

# "a" is set on the group (overriding the default of 10); "b" exists only in defaults.
z = ParamGroup(group={"a": 1}, defaults={"b": 2, "a": 10})
dict(z)

{'b': 2, 'a': 1}

In [None]:
# Display the merged view: keys from `defaults` first, with the group's own values overriding them.
z.items() 

dict_items([('b', 2), ('a', 1)])

In [None]:
from collections import defaultdict
from typing import Any
from collections.abc import Iterable, Sequence, Mapping
from abc import ABC, abstractmethod
from torchzero.utils.optimizer import ParamFilter, ListLike, _param_filter
from torchzero.utils import TensorList
import torch
Params = Iterable[torch.Tensor | tuple[str, torch.Tensor] | Sequence | Mapping[str, Any]]

def _make_param_groups(params: Params) -> list[dict[str, Any]]:
    """Placeholder stub in this design sketch; as written it returns None.

    ``Module.set_params`` passes its ``params`` argument here and stores the
    result as the module's custom param groups. Per the return annotation the
    real version presumably normalizes ``params`` into a list of param-group
    dicts -- TODO confirm against the actual implementation.
    """
    ...
def maybe_chain(modules: "Module | Iterable[Module]") -> "Module":
    """Placeholder stub in this design sketch; as written it returns None.

    ``Module.set_child`` passes either a single Module or an iterable of
    Modules here and stores the result as a child. The real version
    presumably wraps an iterable into a single chained Module (a package
    ``.chain`` module is referenced near ``set_child``) -- TODO confirm.
    """
    ...

class Module(ABC):
    def __init__(self, defaults: dict[str, Any] | None = None):
        """Create a module with empty state and no param groups attached.

        Args:
            defaults: default hyperparameter values for this module. Stored
                as-is (the caller's dict is not copied); ``None`` means ``{}``.
        """
        self.defaults: dict[str, Any] = {} if defaults is None else defaults

        # Starts empty; under the "param_groups in var" design it is assigned
        # temporarily before stepping.
        self.param_groups: list[dict[str, Any]] = []

        # Per-parameter state (presumably keyed by parameter tensor -- confirm)
        # plus state shared across the whole module.
        self.state: defaultdict[Any, dict[str, Any]] = defaultdict(dict)
        self.global_state: dict[str, Any] = {}

        # Named child modules, registered through set_child.
        self.children: dict[str, Module] = {}

        # Groups supplied via set_params; stays empty when none were given.
        self._custom_param_groups: list[dict[str, Any]] = []

        # NOTE: no `_initialized` flag -- under this design initialization
        # logic looks unnecessary (still an open question).

    # not needed methods
    # def _initialize(self, params: Params): ...
    def set_params(self, params: Params):
        """Give this module its own param groups, built from ``params``.

        ``params`` is normalized by ``_make_param_groups`` and the result is
        stored in ``self._custom_param_groups``; nothing is returned.
        """
        custom_groups = _make_param_groups(params)
        self._custom_param_groups = custom_groups

    def get_params(self, mode: ParamFilter = 'requires_grad', cls: type[ListLike] = TensorList) -> ListLike:
        """Collect the params from ``self.param_groups`` that pass the ``mode`` filter.

        Args:
            mode: filter forwarded as ``_param_filter(param, mode)`` for every
                param in every group.
            cls: container type; it is called once with a lazy generator of
                the selected params.

        Returns:
            ``cls`` built from the selected params, in group order and then
            in-group order. ``self.param_groups`` starts empty (see
            ``__init__``), so the result depends on groups having been
            assigned first.
        """
        selected = (
            param
            for group in self.param_groups
            for param in group['params']
            if _param_filter(param, mode)
        )
        return cls(selected) # type:ignore

    def set_child(self, key: str, module: "Module | Iterable[Module]"):
        """Register ``module`` as the child stored under ``key``.

        ``module`` (a single Module or an iterable of Modules) is passed
        through ``maybe_chain``, and the result replaces any existing child
        with the same key.
        """
        # NOTE(review): `maybe_chain` is expected to come from the package's
        # `.chain` module in the real implementation; confirm the import.
        child = maybe_chain(module)
        self.children[key] = child
