Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add features to ZenStore #569

Merged
merged 8 commits into from Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/changes.rst
Expand Up @@ -68,6 +68,7 @@ Improvements
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.make_custom_builds_fn` now accept a `zen_exclude` field for excluding parameters from auto-population, either by name or by pattern. See :pull:`558`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.just` can now configure static methods. Previously the incorrect ``_target_`` would be resolved. See :pull:`566`
- Adds formal support for Python 3.12. See :pull:`555`
- Several new methods were added to :class:`~hydra_zen.ZenStore`, including the abilities to copy, update, and merge stores, as well as to remap the groups of a store's entries and to delete individual entries. See :pull:`569`


.. _v0.11.0:
Expand Down
6 changes: 6 additions & 0 deletions docs/source/generated/hydra_zen.ZenStore.rst
Expand Up @@ -12,8 +12,14 @@ hydra\_zen.ZenStore
.. automethod:: __call__
.. automethod:: __getitem__
.. automethod:: add_to_hydra_store
.. automethod:: copy
.. automethod:: copy_with_mapped_groups
.. automethod:: get_entry
.. automethod:: delete_entry
.. automethod:: update
.. automethod:: merge
.. automethod:: has_enqueued
.. automethod:: enqueue_all
.. automethod:: __iter__
.. automethod:: __eq__

Expand Down
279 changes: 260 additions & 19 deletions src/hydra_zen/wrapper/_implementations.py
Expand Up @@ -3,14 +3,14 @@
# pyright: strict, reportUnnecessaryTypeIgnoreComment = true, reportUnnecessaryIsInstance = false

import warnings
from collections import defaultdict, deque
from collections import defaultdict
from copy import deepcopy
from functools import wraps
from inspect import Parameter, signature
from typing import (
Any,
Callable,
DefaultDict,
Deque,
Dict,
FrozenSet,
Generator,
Expand Down Expand Up @@ -1271,6 +1271,12 @@ def func(x, y): ...
schema: <none>
has_root: true
_target_: __main__.Profile

**Manipulating and updating a store**

A store can be copied, updated, and merged. Its entries can have their groups
remapped, and individual entries can be deleted. See the docs for the corresponding
methods for details and examples.
"""

__slots__ = (
Expand Down Expand Up @@ -1339,7 +1345,7 @@ def __init__(
# created via the 'self-partialing' process
self._internal_repo: Dict[Tuple[GroupName, NodeName], StoreEntry] = {}
# Internal repo entries that have yet to be added to Hydra's config store
self._queue: Deque[StoreEntry] = deque([])
self._queue: Set[Tuple[GroupName, NodeName]] = set()

self._deferred_to_config = deferred_to_config
self._deferred_store = deferred_hydra_store
Expand Down Expand Up @@ -1409,8 +1415,7 @@ def __call__(
def __call__(self: Self, __target: Optional[F] = None, **kw: Any) -> Union[F, Self]:
"""__call__(target : Optional[T] = None, /, name: NodeName | Callable[[Any], NodeName]] = ..., group: GroupName | Callable[[T], GroupName]] = None, package: Optional[str | Callable[[T], str]]] | None], provider: Optional[str], to_config: Callable[[T], Node] = ..., **to_config_kw: Any) -> T | ZenStore

The interface to an initialized store. Can be used to store a config or to
:ref:`customize the default values <self-partial>` of the store.
Store a config or :ref:`customize the default values <self-partial>` of the store.

Parameters
----------
Expand Down Expand Up @@ -1544,20 +1549,118 @@ def __call__(self: Self, __target: Optional[F] = None, **kw: Any) -> Union[F, Se
node=node,
)

if not self._overwrite_ok and (_group, _name) in self._internal_repo:
raise ValueError(
f"(name={entry['name']} group={entry['group']}): "
f"Store entry already exists. Use a store initialized "
f"with `ZenStore(overwrite_ok=True)` to overwrite config store "
f"entries."
)
self._internal_repo[_group, _name] = entry
self._queue.append(entry)

if not self._deferred_store:
self.add_to_hydra_store()
self._set_entry(entry, overwrite=self._overwrite_ok)
return cast(Union[F, Self], __target)

def copy(self: Self, store_name: Optional[str] = None) -> Self:
    """Return a deep copy of the store with the same overridden defaults.

    Parameters
    ----------
    store_name : str | None, optional (default=None)
        The name given to the new store. If unspecified, the new store is
        named after the original store, with ``"_copy"`` appended.

    Returns
    -------
    ZenStore
        A new store. Adding entries to the copy does not affect the
        original store.

    Examples
    --------
    >>> from hydra_zen import ZenStore
    >>> s1 = ZenStore("s1")(group="G")
    >>> s1({}, name="a")
    >>> s2 = s1.copy()
    >>> s2({}, name="b")
    >>> s1
    s1
    {'G': ['a']}
    >>> s2
    s1_copy
    {'G': ['a', 'b']}
    """
    # Deep-copy so that the two stores never share mutable state.
    cp = deepcopy(self)

    cp.name = self.name + "_copy" if store_name is None else store_name
    return cp

def copy_with_mapped_groups(
    self: Self,
    old_group_to_new_group: Union[
        Mapping[GroupName, GroupName], Callable[[GroupName], GroupName]
    ],
    *,
    store_name: Optional[str] = None,
    overwrite_ok: Optional[bool] = None,
) -> Self:
    """Create a copy of a store whose entries' groups have been updated according to the provided mapping.

    Parameters
    ----------
    old_group_to_new_group : Mapping[GroupName, GroupName] | Callable[[GroupName], GroupName]
        A mapping or callable that transforms an old group name to a new one.
        Groups in the store that are not included in the mapping are unaffected.

        A `GroupName` is `str | None`.

    store_name : Optional[str]
        If specified, the name of the new store.

    overwrite_ok : Optional[bool]
        If specified, determines if the mapping can overwrite existing store
        entries. Otherwise, defers to `ZenStore(overwrite_ok)`.

    Returns
    -------
    new_store
        A copy of `self` with remapped groups.

    Raises
    ------
    ValueError
        If overwriting is not permitted and a remapped entry collides with
        an entry that already occupies its new (group, name) location.

    Notes
    -----
    All remapped entries are removed from their old locations before any of
    them is placed at its new location. Thus mappings that swap or chain
    groups (e.g. ``{"A": "B", "B": "A"}``) do not collide with entries that
    are themselves being moved.

    Examples
    --------
    >>> from hydra_zen import ZenStore

    Creating an initial store

    >>> s1 = ZenStore("s1")
    >>> s1({}, group=None, name="a")
    >>> s1({}, group="A/1", name="b")
    >>> s1({}, group="A/2", name="c")
    >>> s1
    s1
    {None: ['a'], 'A/1': ['b'], 'A/2': ['c']}

    Replacing group "A/1" with "B", via a mapping

    >>> s2 = s1.copy_with_mapped_groups({"A/1": "B"}, store_name="s2")
    >>> s2
    s2
    {None: ['a'], 'A/2': ['c'], 'B': ['b']}

    Placing all entries under group "A/" within a new inner group "p", via a
    function

    >>> s3 = s1.copy_with_mapped_groups(
    ...     lambda g: g + "/p" if g and g.startswith("A/") else g, store_name="s3"
    ... )
    >>> s3
    s3
    {None: ['a'], 'A/1/p': ['b'], 'A/2/p': ['c']}
    """
    overwrite = self._overwrite_ok if overwrite_ok is None else overwrite_ok

    map_fn: Callable[[GroupName], GroupName] = (
        (lambda x: old_group_to_new_group.get(x, x))
        if isinstance(old_group_to_new_group, Mapping)
        else old_group_to_new_group
    )

    new_store = self.copy(store_name)

    # Resolve every relocation up front (calling `map_fn` once per entry) so
    # that the repo is not mutated while it is being inspected.
    relocated = [
        (group, name, new_group, entry)
        for (group, name), entry in new_store._internal_repo.items()
        for new_group in (map_fn(group),)
        if new_group != group
    ]

    # Phase 1: vacate all of the old locations. Doing this before any
    # insertion prevents an entry that is about to move from being treated
    # as a conflict (or from being clobbered and lost).
    for group, name, _, _ in relocated:
        del new_store[group, name]

    # Phase 2: place each entry at its new location; genuine collisions are
    # handled by `_set_entry` according to `overwrite`.
    for _, _, new_group, entry in relocated:
        entry["group"] = new_group
        new_store._set_entry(entry, overwrite=overwrite)
    return new_store

@property
def groups(self) -> Sequence[GroupName]:
"""Returns a sorted list of the groups registered with this store"""
Expand All @@ -1570,6 +1673,27 @@ def groups(self) -> Sequence[GroupName]:
no_none = cast(Set[str], set_)
return sorted(no_none)

def enqueue_all(self) -> None:
    """Add all of the store's entries to the queue to be added to Hydra's store.

    Every entry held by this store is enqueued, including entries that were
    previously removed from the queue by `add_to_hydra_store`.

    Examples
    --------
    >>> from hydra_zen import ZenStore
    >>> store = ZenStore(deferred_hydra_store=True)
    >>> store({"a": 1}, name="a")
    >>> store.has_enqueued()
    True

    >>> store.add_to_hydra_store()
    >>> store.has_enqueued()
    False

    >>> store.enqueue_all()
    >>> store.has_enqueued()
    True
    """
    # Iterating the repo yields its (group, name) keys, which are what the
    # queue tracks.
    self._queue.update(self._internal_repo)

def has_enqueued(self) -> bool:
"""`True` if this store has entries that have not yet been added to
Hydra's config store.
Expand Down Expand Up @@ -1600,6 +1724,84 @@ def __bool__(self) -> bool:
not they have been added to Hydra's config store"""
return bool(self._internal_repo)

def __len__(self) -> int:
    """Return the number of entries held in the store's internal repo."""
    repo = self._internal_repo
    return len(repo)

def update(self, __other: "ZenStore") -> None:
    """Update this store in place with the entries of another store.

    Entries of `self` that share a (group, name) with an entry of the other
    store are overwritten. The incoming entries are deep-copied, and the other
    store's enqueued entries are enqueued in this store as well. Unless this
    store defers to Hydra's store, `add_to_hydra_store` is then called.

    Can also be applied via the `|=` in-place operator.

    Parameters
    ----------
    __other : ZenStore
        The store whose entries are copied into this one.

    Examples
    --------
    >>> from hydra_zen import ZenStore
    >>> def f(): ...
    >>> def g(): ...
    >>> s1 = ZenStore("s1")
    >>> s2 = ZenStore("s2")
    >>> s1(f)  # store f in s1
    >>> s2(g)  # store g in s2

    >>> s1.update(s2)
    >>> s1  # s1 now has entries for both f and g
    s1
    {None: ['f', 'g']}

    Alternatively, the `|=` operator can be used to update a store inplace.

    >>> s3 = ZenStore("s3")
    >>> s3 |= s2
    >>> s3
    s3
    {None: ['g']}
    """
    # Copy the incoming entries so the two stores do not share them.
    incoming = deepcopy(__other._internal_repo)
    self._internal_repo.update(incoming)

    pending = __other._queue
    self._queue.update(pending)

    if self._deferred_store:
        return
    self.add_to_hydra_store()

def merge(
    self: Self, __other: "ZenStore", store_name: Optional[str] = None
) -> Self:
    """Create a new store by merging two stores.

    A copy of `self` is made and then updated with the other store, so the
    new store's default settings will reflect those of `self` in
    `self.merge(other)`. This can also be applied via the `|` operator.

    Parameters
    ----------
    __other : ZenStore
        The store to merge into the copy of `self`.

    store_name : str | None, optional (default=None)
        Passed to `copy` when creating the new store.

    Returns
    -------
    ZenStore
        The new, merged store.

    Examples
    --------
    >>> from hydra_zen import ZenStore
    >>> def f(): ...
    >>> def g(): ...
    >>> s1 = ZenStore("s1")
    >>> s2 = ZenStore("s2")
    >>> s1(f)  # store f in s1
    >>> s2(g)  # store g in s2

    >>> s3 = s1.merge(s2)
    >>> s3
    s1_copy
    {None: ['f', 'g']}

    Alternatively, the `|` operator can be used to merge stores.

    >>> s4 = s1 | s2
    >>> s4
    s1_copy
    {None: ['f', 'g']}
    """
    merged = self.copy(store_name)
    merged.update(__other)
    return merged

def __or__(self: Self, other: "ZenStore") -> Self:
    """Support `self | other`; equivalent to `self.merge(other)`."""
    merged = self.merge(other)
    return merged

def __ior__(self: Self, other: "ZenStore") -> Self:
    """Support `self |= other`; updates `self` via `self.update(other)`."""
    self.update(other)
    # In-place operators must return the object to be rebound.
    return self

@overload
def __getitem__(self, key: Tuple[GroupName, NodeName]) -> Node:
...
Expand Down Expand Up @@ -1661,10 +1863,31 @@ def __getitem__(self, key: Union[GroupName, Tuple[GroupName, NodeName]]) -> Node
}
return _resolve_node(self._internal_repo[key], copy=False)["node"]

def __delitem__(self, key: Tuple[GroupName, NodeName]) -> None:
    """Support `del store[group, name]`.

    Removes the entry from the store's internal repo and drops it from the
    queue of entries awaiting addition to Hydra's store.

    Raises
    ------
    KeyError
        If no entry is stored under `key`.
    """
    # Remove from the repo first: a missing key raises before the queue is
    # touched.
    self._internal_repo.pop(key)
    self._queue.discard(key)

def delete_entry(self, group: GroupName, name: NodeName) -> None:
    """Delete the entry stored under `(group, name)`.

    Equivalent to `del store[group, name]`.

    Parameters
    ----------
    group : str | None
    name : str
    """
    key = (group, name)
    del self[key]

def get_entry(self, group: GroupName, name: NodeName) -> StoreEntry:
"""Access a store entry, which is a mapping that specifies the entry's
name, group, package, provider, and node.

Parameters
----------
group : str | None
name : str

Returns
-------
dict
- name: NodeName
- group: GroupName
- package: Optional[str]
- provider: Optional[str]
- node: ConfigType

Notes
-----
Mutating the returned mapping will not affect the store's internal entry.
Expand All @@ -1684,6 +1907,22 @@ def get_entry(self, group: GroupName, name: NodeName) -> StoreEntry:
"""
return _resolve_node(self._internal_repo[(group, name)], copy=True)

def _set_entry(self, __entry: StoreEntry, overwrite: bool) -> None:
    """Store an entry in the internal repo and enqueue it for Hydra's store.

    Unless this store defers to Hydra's store, `add_to_hydra_store` is called
    after the entry is enqueued.

    Parameters
    ----------
    __entry : StoreEntry
        The entry to store; its "group" and "name" determine where it is
        stored.

    overwrite : bool
        If `False`, an existing entry with the same (group, name) is not
        replaced and an error is raised instead.

    Raises
    ------
    ValueError
        If `overwrite` is `False` and an entry with the same (group, name)
        already exists in this store.
    """
    key = (__entry["group"], __entry["name"])

    if not overwrite and key in self._internal_repo:
        raise ValueError(
            f"(name={__entry['name']} group={__entry['group']}): "
            f"Store entry already exists. Use a store initialized "
            f"with `ZenStore(overwrite_ok=True)` to overwrite config store "
            f"entries."
        )

    self._internal_repo[key] = __entry
    self._queue.add(key)

    if not self._deferred_store:
        self.add_to_hydra_store()

def __contains__(self, key: Union[GroupName, Tuple[GroupName, NodeName]]) -> bool:
"""Checks if group or (group, node-name) exists in zen-store."""
if key is None:
Expand Down Expand Up @@ -1760,8 +1999,9 @@ def add_to_hydra_store(self, overwrite_ok: Optional[bool] = None) -> None:

"""
_store = ConfigStore.instance().store
while self._queue:
entry = _resolve_node(self._queue.popleft(), copy=False)

for key in tuple(self._queue):
entry = _resolve_node(self._internal_repo[key], copy=False)
if (
(
overwrite_ok is False
Expand All @@ -1783,6 +2023,7 @@ def add_to_hydra_store(self, overwrite_ok: Optional[bool] = None) -> None:
f"`overwrite_ok=True` to enable replacing config store entries"
)
_store(**entry)
self._queue.discard(key)

def _exists_in_hydra_store(
self,
Expand Down