In [56]:
import re
import json
from copy import deepcopy
from ast import literal_eval
from collections import defaultdict

In [28]:
def pretty_print(d: dict):
    """Print ``d`` to stdout as JSON, indented by two spaces per nesting level."""
    rendered = json.dumps(d, indent="  ")
    print(rendered)

In [29]:
def merge_dicts(first: dict, second: dict):
    """
    Recursively merge two dictionaries into a new one.

    Unlike ``dict.update()``, values that are dictionaries in *both* inputs are
    merged key by key instead of the former being replaced by the latter.
    For every other key present in both inputs, the value from ``second`` wins.

    Neither input is modified; all values in the result are deep copies.
    The result keeps a deterministic key order: the keys of ``first`` in their
    original order, followed by the keys that only exist in ``second``.

    Args:
        first (dict): the base dictionary
        second (dict): the dictionary whose values take precedence

    Returns:
        (dict): a new, independent dictionary containing the merged content
    """
    merged = {}

    for key, first_value in first.items():
        if key not in second:
            merged[key] = deepcopy(first_value)
            continue

        second_value = second[key]
        if isinstance(first_value, dict) and isinstance(second_value, dict):
            # both sides are dicts -> descend instead of overriding
            merged[key] = merge_dicts(first_value, second_value)
        else:
            merged[key] = deepcopy(second_value)

    # keys that only the second dictionary knows about
    for key, second_value in second.items():
        if key not in first:
            merged[key] = deepcopy(second_value)

    return merged

In [125]:
def nested_dict_set(key: str, value: any, d: dict, sep: str=".") -> dict:
    """
    Sets the value of an already existing entry in a nested dictionary.

    The key is a path whose segments are joined by a separator (sep). Each segment
    is either a plain name (``optim``) or a name with a list index (``adv_groups[1]``),
    in which case the entry stored under the name must be a list.

    Note that as dictionaries are mutable and passed by reference, we modify the dictionary in-place
    to make the requested changes. Still, the dict is returned to suit all use-cases.

    Args:
        key (str): the (possibly nested) key to look for in the dictionary
        value (any): the value to set the dictionary item to
        d (dict): the dictionary for which key values should be replaced
        sep (str): the string/char that splits the different levels of the dict keys

    Returns:
        (dict): the modified dictionary (Note that the dictionary is modified in-place nonetheless.)

    Raises:
        KeyError: if a segment is malformed, if any segment of the path does not exist,
            or if a list index is given for an entry that is not a list
        IndexError: if a list index is outside of the addressed list
    """
    # One path segment: either `name[index]` or a bare `name`; names start with a letter.
    # Kept local so the function does not depend on a global defined in another cell.
    key_pattern = r"^((([a-zA-Z]+\w*)\[(\d+)\])|([a-zA-Z]+\w*))$"

    steps = key.split(sep)
    last_position = len(steps) - 1
    node = d

    for position, step in enumerate(steps):
        # basically validates and searches for list indices
        res = re.search(key_pattern, step)
        if res is None:
            raise KeyError(f"format of key '{step}' not supported.")
        _, _, key_for_list, list_idx, key_for_dict = res.groups()
        list_idx_provided = list_idx is not None

        # debug trace of how the segment was parsed
        print(f"{key_for_list=}, {list_idx=}, {list_idx_provided=}, {key_for_dict=}")

        k = key_for_list or key_for_dict
        list_idx = int(list_idx) if list_idx_provided else None

        # failed to find key
        if not (isinstance(node, dict) and k in node):
            break

        element = node[k]
        if list_idx_provided and not isinstance(element, list):
            raise KeyError(f"list index provided in key '{k}', but no list found.")

        # key found
        if position == last_position:
            if list_idx_provided:
                element[list_idx] = value
            else:
                # must write through the parent; rebinding `element` would change nothing
                node[k] = value
            return d

        # continue search one level deeper
        node = element[list_idx] if list_idx_provided else element

    raise KeyError(key)

In [157]:
# Scratch check of an inline "split off the first segment, keep the rest" idiom.
# Gotcha: the conditional expression binds tighter than the tuple comma, so the
# right-hand side parses as `(res if len(res) == 2 else res), None`. `k` therefore
# receives the whole list and `rest` is always None, as the output below shows.
res = "ab22".split(".", maxsplit=1)
k, rest = res if len(res) == 2 else res, None
k, rest

(['ab22'], None)

In [188]:
def nested_dict_set(key: str, value: any, d: dict, sep: str=".") -> dict:
    """
    Sets the value of an already existing entry in a nested dictionary.

    The key is a path whose segments are joined by a separator (sep). Each segment
    is either a plain name (``optim``) or a name with a list index (``adv_groups[1]``),
    in which case the entry stored under the name must be a list.

    Note that as dictionaries are mutable and passed by reference, we modify the dictionary in-place
    to make the requested changes. Still, the dict is returned to suit all use-cases.

    Args:
        key (str): the (possibly nested) key to look for in the dictionary
        value (any): the value to set the dictionary item to
        d (dict): the dictionary for which key values should be replaced
        sep (str): the string/char that splits the different levels of the dict keys

    Returns:
        (dict): the modified dictionary (Note that the dictionary is modified in-place nonetheless.)

    Raises:
        KeyError: if a segment is malformed, if any segment of the path does not exist,
            or if a list index is given for an entry that is not a list
        IndexError: if a list index is outside of the addressed list
    """
    # The helper returns the innermost container it wrote to; callers of the
    # public function are promised the root dictionary, so return that explicitly.
    _nested_dict_set(key, key, value, d, sep)
    return d


# One path segment: either `name[index]` or a bare `name`; names start with a letter.
# Defined next to the functions so they do not depend on a global from another cell.
_KEY_SEGMENT_PATTERN = re.compile(r"^((([a-zA-Z]+\w*)\[(\d+)\])|([a-zA-Z]+\w*))$")


def _split_with_rest(s: str, sep: str, n_splits=1):
    """
    Splits the specified string at most "n_splits" times and always returns a tuple
    of n_splits + 1 items. If the string could be split fewer times than requested,
    the missing trailing items are None.
    """
    result = s.split(sep, maxsplit=n_splits)
    return tuple(list(result) + [None] * (n_splits + 1 - len(result)))


def _nested_dict_set(full_key: str, key: str, value: any, d: dict, sep: str=".") -> dict:
    """
    Recursive worker of nested_dict_set: handles the first segment of "key"
    and recurses into the addressed child for the remaining segments.

    Args:
        full_key (str): the complete key as passed by the user, only used in error messages
        key (str): the part of the key that still has to be resolved, relative to d
        value (any): the value to set the dictionary item to
        d (dict): the container the first segment of "key" is looked up in
        sep (str): the string/char that splits the different levels of the dict keys

    Returns:
        (dict): the innermost dictionary in which the final segment was set

    Raises:
        KeyError: if the segment is malformed, not present in d, or carries a list
            index while the stored entry is not a list
    """
    k, rest = _split_with_rest(key, sep, n_splits=1)

    # basically validates and searches for list indices
    print(f"searching for pattern {k}")
    res = _KEY_SEGMENT_PATTERN.search(k)
    if res is None:
        raise KeyError(f"format of key '{k}' not supported.")
    _, _, key_for_list, list_idx, key_for_dict = res.groups()
    list_idx_provided = list_idx is not None

    # debug trace of how the segment was parsed
    print(f"{key_for_list=}, {list_idx=}, {list_idx_provided=}, {key_for_dict=}")

    k = key_for_list or key_for_dict
    list_idx = int(list_idx) if list_idx_provided else None

    # failed to find key
    if not (isinstance(d, dict) and k in d):
        raise KeyError(full_key)

    if list_idx_provided and not isinstance(d[k], list):
        raise KeyError(f"list index provided with key {full_key} but no list found.")

    # last segment reached -> write the value
    if rest is None:
        if list_idx_provided:
            d[k][list_idx] = value
        else:
            d[k] = value
        return d

    # continue search; the separator has to be handed down, otherwise deeper
    # levels would silently fall back to the default "."
    child = d[k][list_idx] if list_idx_provided else d[k]
    return _nested_dict_set(full_key, rest, value, child, sep)

In [189]:
# Asking for more splits than the string has separators pads the tuple with None.
_split_with_rest("ab.c.d.e", ".", 4)

('ab', 'c', 'd', 'e', None)

In [190]:
# Un-anchored first draft of the key pattern: matches either `name[index]` or a bare name.
re.findall(r"((\w+)\[(\d+)\])|(\w+)", "optim[1]")

[('optim[1]', 'optim', '1', '')]

In [191]:
# Anchored key pattern: a whole segment must be either `name[index]` or a bare name,
# and the name has to start with a letter.
pattern = r"^((([a-zA-Z]+\w*)\[(\d+)\])|([a-zA-Z]+\w*))$"
res = re.search(pattern, "opti[1]").groups()
display(res)

# Groups 3 and 4 hold the list name and its index; group 5 is only filled for a bare name
# and is therefore None for this input.
_, _, key_for_list, list_idx, key_for_dict = res
int(list_idx)

('opti[1]', 'opti[1]', 'opti', '1', None)

1

In [207]:
# Example configuration to be modified: a nested dict plus a list of dicts.
a = {
    "optim": {     
        "lr": 1e-3,
        "weight_decay": 1e-4,   
    },
    
    "adv_groups": [
        { 
            "feature": "gender",
        },
        {
            "grad_scaling": 1
        }
    ]
}

# Overrides keyed by nested path: `.` descends into a dict, `[i]` addresses a list item.
# The entries are applied in insertion order, so the final "optim" entry replaces the
# whole sub-dict that "optim.lr" modified before it.
grid_config = {
    "optim.lr": 1e-3,
    "adv_groups[0]": "hannes",
    "adv_groups[1]": {
        "a": 1,
        "b": 2
    },
    "optim": "overriden things"
}

In [208]:
# Work on a deep copy: nested_dict_set modifies its argument in-place and `a` should stay intact.
b = deepcopy(a)
for k, v in grid_config.items():
    print(f"Setting '{v}' for nested key '{k}'")
    nested_dict_set(k, v, b)
    print()  # blank line between the traces of two overrides
pretty_print(b)

Setting '0.001' for nested key 'optim.lr'
searching for pattern optim
key_for_list=None, list_idx=None, list_idx_provided=False, key_for_dict='optim'
searching for pattern lr
key_for_list=None, list_idx=None, list_idx_provided=False, key_for_dict='lr'

Setting 'hannes' for nested key 'adv_groups[0]'
searching for pattern adv_groups[0]
key_for_list='adv_groups', list_idx='0', list_idx_provided=True, key_for_dict=None

Setting '{'a': 1, 'b': 2}' for nested key 'adv_groups[1]'
searching for pattern adv_groups[1]
key_for_list='adv_groups', list_idx='1', list_idx_provided=True, key_for_dict=None

Setting 'overriden things' for nested key 'optim'
searching for pattern optim
key_for_list=None, list_idx=None, list_idx_provided=False, key_for_dict='optim'

{
  "optim": "overriden things",
  "adv_groups": [
    "hannes",
    {
      "a": 1,
      "b": 2
    }
  ]
}


In [109]:
# Demonstrates that an item fetched from a nested structure is an alias, not a copy:
# writing through the alias is visible in the containing dictionary.
# Start from a fresh copy of `a` so the cell does not depend on what earlier cells
# did to `b` (after the overrides above, the first group is a plain string and
# item assignment on it would fail).
b = deepcopy(a)
first_group = b["adv_groups"][0]
first_group["feature"] = 2
b

{'optim': {'lr': 0.001, 'weight_decay': 0.0001},
 'adv_groups': [{'feature': 2}, {'grad_scaling': 1}]}

In [2]:
# Sample data: three outer entries (suffix log1..log3), each holding three groups
# (g1, g2, g3) with three, two and one label(s) respectively.
# NOTE(review): in every g1 list the second and third label are identical ("..adv2..");
# possibly the third one was meant to be "adv3" - confirm.
test_logit_groups = [
    [["g1_adv1_log1", "g1_adv2_log1", "g1_adv2_log1"], ["g2_adv1_log1", "g2_adv2_log1"], ["g3_adv1_log1"]],
    [["g1_adv1_log2", "g1_adv2_log2", "g1_adv2_log2"], ["g2_adv1_log2", "g2_adv2_log2"], ["g3_adv1_log2"]],
    [["g1_adv1_log3", "g1_adv2_log3", "g1_adv2_log3"], ["g2_adv1_log3", "g2_adv2_log3"], ["g3_adv1_log3"]],
]

In [8]:
# Transpose the outer two levels: regroup the entries so that everything belonging
# to the same group (g1, g2, g3) ends up in one list. zip(*...) yields tuples,
# hence the conversion back to lists.
result = [list(grp) for grp in zip(*test_logit_groups)]
result

[[['g1_adv1_log1', 'g1_adv2_log1', 'g1_adv2_log1'],
  ['g1_adv1_log2', 'g1_adv2_log2', 'g1_adv2_log2'],
  ['g1_adv1_log3', 'g1_adv2_log3', 'g1_adv2_log3']],
 [['g2_adv1_log1', 'g2_adv2_log1'],
  ['g2_adv1_log2', 'g2_adv2_log2'],
  ['g2_adv1_log3', 'g2_adv2_log3']],
 [['g3_adv1_log1'], ['g3_adv1_log2'], ['g3_adv1_log3']]]

In [7]:
# After the transposition the first item collects all g1 lists (one per log index).
result[0]

[['g1_adv1_log1', 'g1_adv2_log1', 'g1_adv2_log1'],
 ['g1_adv1_log2', 'g1_adv2_log2', 'g1_adv2_log2'],
 ['g1_adv1_log3', 'g1_adv2_log3', 'g1_adv2_log3']]

In [19]:
from collections import Counter
# Prototype: make a list of names unique by numbering every name that occurs
# more than once (name_0, name_1, ...) while leaving one-off names untouched.
features = ["A", "B", "C", "A", "C", "A", "C", "A", "C", "A", "C"]

feature_counter = Counter(features)
non_unique_features = [name for name, count in feature_counter.items() if count >= 2]
# next running number to hand out, per repeated name
non_unique_feature_counter = dict.fromkeys(non_unique_features, 0)

feature_unique_names = []
for feature in features:
    if feature not in non_unique_features:
        feature_unique_names.append(feature)
        continue
    occurrence = non_unique_feature_counter[feature]
    feature_unique_names.append(f"{feature}_{occurrence}")
    non_unique_feature_counter[feature] = occurrence + 1

feature_unique_names

['A_0', 'B', 'C_0', 'A_1', 'C_1', 'A_2', 'C_2', 'A_3', 'C_3', 'A_4', 'C_4']

In [15]:
# Occurrences per feature name.
# NOTE(review): the recorded output below does not match the `features` list defined
# above (which contains "A" and "C" five times each); it presumably stems from an
# earlier value of `features` - re-run to refresh.
Counter(features)

Counter({'A': 2, 'C': 2, 'B': 1})

In [21]:
from collections import Counter

def create_unique_names(names: list):
    """
    Makes the given names unique while keeping their order.

    Names that occur only once are returned unchanged. Names that occur several
    times get a running number appended ("A" -> "A_0", "A_1", ...). Should such a
    generated name already be in use (e.g. the input contains both "A" twice and
    a literal "A_0"), the next free number is used instead, so the result never
    contains duplicates.

    Args:
        names (list): the (hashable) names, possibly containing duplicates

    Returns:
        (list): a new list of the same length in which every entry is unique
    """
    counter = Counter(names)
    # next running number to try, per repeated name
    next_suffix = {n: 0 for n, c in counter.items() if c >= 2}
    # names that are already spoken for: all one-off names plus everything generated so far
    taken = {n for n, c in counter.items() if c == 1}

    unique_names = []
    for n in names:
        if n not in next_suffix:
            unique_names.append(n)
            continue

        idx = next_suffix[n]
        candidate = f"{n}_{idx}"
        while candidate in taken:
            # skip numbers whose resulting name would clash with an existing one
            idx += 1
            candidate = f"{n}_{idx}"

        next_suffix[n] = idx + 1
        taken.add(candidate)
        unique_names.append(candidate)

    return unique_names


# Same input as in the prototype cell; the function is expected to produce the same names.
features = ["A", "B", "C", "A", "C", "A", "C", "A", "C", "A", "C"]
create_unique_names(features)


['A_0', 'B', 'C_0', 'A_1', 'C_1', 'A_2', 'C_2', 'A_3', 'C_3', 'A_4', 'C_4']