Skip to content

Commit

Permalink
Add support to update_from_flattened_dict for updating a single specific
Browse files Browse the repository at this point in the history
(indexed) element of list/tuple. of a ConfigDict.

PiperOrigin-RevId: 549703559
Change-Id: If4682b3b1351082b9aa957093e3783c4b121640c
  • Loading branch information
ML Collections Contributor committed Jul 20, 2023
1 parent ec0d364 commit 897bf21
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 19 deletions.
80 changes: 61 additions & 19 deletions ml_collections/config_dict/config_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
import inspect
import json
import operator
from typing import Any, Mapping, Optional
import re
from typing import Any, Mapping, Optional, Tuple

from absl import logging

Expand Down Expand Up @@ -157,6 +158,18 @@ def _get_computed_value(value_or_fieldreference):
return value_or_fieldreference


def _parse_key(key: str) -> Tuple[str, Optional[int]]:
"""Parse a ConfigDict key into to it's initial part and index (if any)."""
key = key.split('.')[0]
index_match = re.match("(.*)\[([0-9]+)\]", key)
if index_match:
key = index_match.group(1)
index = int(index_match.group(2))
else:
index = None
return key, index


class _Op(collections.namedtuple('_Op', ['fn', 'args'])):
"""A named tuple representing a lazily computed op.
Expand Down Expand Up @@ -1344,6 +1357,17 @@ def update(self, *other, **kwargs):
else:
self[key] = value

def _update_value(self, key, index, value):
if index is None:
self[key] = value
elif isinstance(self[key], list):
self[key][index] = value
elif isinstance(self[key], tuple):
# Tuples are immutable, so convert to list, update and convert back.
tuple_as_list = list(self[key])
tuple_as_list[index] = value
self[key] = tuple(tuple_as_list)

def update_from_flattened_dict(self, flattened_dict, strip_prefix=''):
"""In-place updates values taken from a flattened dict.
Expand Down Expand Up @@ -1416,31 +1440,49 @@ def update_from_flattened_dict(self, flattened_dict, strip_prefix=''):

# Keep track of any children that we want to update. Make sure that we
# recurse into each one only once.
children_to_update = set()
children_to_update = {}

for full_key, value in six.iteritems(interesting_items):
key = full_key[len(strip_prefix):] if strip_prefix else full_key

# If the path is hierarchical, we'll need to tell the first component
# to update itself.
full_child = key.split('.')[0]

# Check to see if we are trying to update a single element of a tuple/list
#
# TODO(kkg): The key/index parsing & handling logic below duplicates
# similar logic in the config_flags/config_path module. Ideally we should
# refactor the code to reuse the 'config_path' module here - but that is
# likely a significant effort since that module already depends on this
# leading to a circular dependency.
child, index = _parse_key(full_child)

if not child in self:
raise KeyError('Key "{}" cannot be set as "{}" was not found.'
.format(full_key, strip_prefix + child))

if index is not None and not isinstance(self[child], (list, tuple)):
raise KeyError('Key "{}" cannot be set as "{}" is not a tuple/list.'
.format(full_key, strip_prefix + child))

if '.' in key:
# If the path is hierarchical, we'll need to tell the first component
# to update itself.
child = key.split('.')[0]
if child in self:
if isinstance(self[child], ConfigDict):
children_to_update.add(child)
else:
raise KeyError('Key "{}" cannot be updated as "{}" is not a '
'ConfigDict.'.format(full_key, strip_prefix + child))
else:
raise KeyError('Key "{}" cannot be set as "{}" was not found.'
.format(full_key, strip_prefix + child))
child_value = self[child] if index is None else self[child][index]
if not isinstance(child_value, ConfigDict):
raise KeyError(
'Key "{}" cannot be updated as "{}" is not a ConfigDict ({}).'
.format(full_key, strip_prefix + full_child, type(child_value))
)

children_to_update[full_child] = child_value
else:
self[key] = value
self._update_value(child, index, value)

for child in children_to_update:
child_strip_prefix = strip_prefix + child + '.'
self[child].update_from_flattened_dict(interesting_items,
child_strip_prefix)
for full_child, child_value in children_to_update.items():
child_strip_prefix = f'{strip_prefix}{full_child}.'
child_value.update_from_flattened_dict(
interesting_items, child_strip_prefix
)


def _frozenconfigdict_valid_input(obj, ancestor_list=None):
Expand Down
28 changes: 28 additions & 0 deletions ml_collections/config_dict/tests/config_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,34 @@ def testUpdateFromFlattenedTupleListConversion(self):
self.assertIsInstance(cfg.b.c.d, tuple)
self.assertEqual(cfg.b.c.d, (2, 4, 6, 8))

def testUpdateFromFlattenedTupleListIndexConversion(self):
cfg = config_dict.ConfigDict({
'a': 1,
'b': {
'c': {
'd': (1, 2, 3, 4, 5),
},
'e': [
config_dict.ConfigDict({
'f': 4,
'g': 5,
}),
config_dict.ConfigDict({
'f': 6,
'g': 7,
}),
],
}
})
updates = {
'b.c.d[2]': 9,
'b.e[1].f': 12,
}
cfg.update_from_flattened_dict(updates)
self.assertIsInstance(cfg.b.c.d, tuple)
self.assertEqual(cfg.b.c.d, (1, 2, 9, 4, 5))
self.assertCountEqual(cfg.b.e[1], {'f': 12, 'g': 7})

def testDecodeError(self):
# ConfigDict containing two strings with incompatible encodings.
cfg = config_dict.ConfigDict({
Expand Down

0 comments on commit 897bf21

Please sign in to comment.