In [None]:
#|default_exp ds

# Data Structures

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

## Nested dictionary 

In [None]:
#|export

def flatten_dict(d: dict, sep='.') -> dict:
    """Flatten a nested dict into a single-level dict with joined keys.

    The keys along each path are joined with `sep`, so ``{'a': {'b': 1}}``
    becomes ``{'a.b': 1}``. Top-level keys whose value is not a dict are kept
    exactly as they are (they are not converted to strings). Empty nested
    dicts contribute no entries to the result.

    Args:
        d: The (possibly nested) dict to flatten.
        sep: Separator placed between the keys of consecutive nesting levels.

    Returns:
        A new flat dict; `d` is not modified.
    """
    # Sentinel marking "no parent yet". A truthiness test on the parent key
    # would silently drop falsy keys such as '' or 0 from the joined path.
    root = object()

    def recurse(subdict, parent_key=root):
        result = {}
        for k, v in subdict.items():
            new_key = k if parent_key is root else f"{parent_key}{sep}{k}"
            if isinstance(v, dict):
                result.update(recurse(v, new_key))
            else:
                result[new_key] = v
        return result

    return recurse(d)

def unflatten_dict(d: dict, sep='.') -> dict:
    """Rebuild a nested dict from a flat dict whose keys are joined paths.

    Each string key is split on `sep`; all parts but the last name the nested
    dicts to descend into (created on demand) and the last part receives the
    value. Keys that are not strings cannot be split and are stored as a
    single key at the top level.

    Args:
        d: Flat dict, e.g. ``{'a.b': 1, 'c': 2}``.
        sep: Separator that joins the nesting levels inside each key.

    Returns:
        A new nested dict, e.g. ``{'a': {'b': 1}, 'c': 2}``.
    """
    res = {}
    for k, v in d.items():
        # Only strings have `.split`; anything else is treated as one key.
        subkeys = k.split(sep) if isinstance(k, str) else [k]
        *parents, leaf = subkeys
        container = res
        for subkey in parents:
            if subkey not in container:
                container[subkey] = {}
            container = container[subkey]
        container[leaf] = v
    return res

In [None]:
#|hide

# Dicts without any nesting must pass through both helpers unchanged.
noop_cases = [
    {},
    {'a': 1, 'b': 2},
]
for d in noop_cases:
    test_eq(unflatten_dict(d), d)
    test_eq(flatten_dict(d), d)

In [None]:
#|hide

# Fixture: the same configuration written twice, once flattened with '/'
# and once in nested form.
flat_dict = {
    'dataset_path': 'a/b/c/d',
    'train/lr': 0.0001,
    'train/n_epoch': 10,
    'train/early_stop/patience': 10,
    'train/early_stop/metric': 'val_loss',
    'wandb/username': 'bdsaglam',
    'wandb/project': 'project-x',
}

nested_dict = {
    'dataset_path': 'a/b/c/d',
    'train': {
        'lr': 1e-4,
        'n_epoch': 10,
        'early_stop': {'patience': 10, 'metric': 'val_loss'},
    },
    'wandb': {'username': 'bdsaglam', 'project': 'project-x'},
}

# Converting in either direction must reproduce the other form.
test_eq(unflatten_dict(flat_dict, sep='/'), nested_dict)
test_eq(flatten_dict(nested_dict, sep='/'), flat_dict)

In [None]:
#|export

class NestedDict(dict):
    def __init__(self, data, sep='.'):
        super().__init__(data)
        self.sep = sep
    
    def at(self, keys: str | list | tuple, default=None):
        if isinstance(keys, str):
            keys = keys.split(self.sep)
        node = self
        for key in keys:
            if key not in node:
                return default
            node = node.get(key)
        return node

    def set(self, keys: str | list | tuple, value):
        if isinstance(keys, str):
            keys = keys.split(self.sep)
        node = self
        last_key = keys.pop()
        for key in keys:
            if key not in node:
                node[key] = dict()
            node = node[key]
        node[last_key] = value

    def flat(self) -> dict:
        return flatten_dict(self, sep=self.sep)
    
    @classmethod
    def from_flat_dict(cls, data, sep='.'):
        return cls(unflatten_dict(data, sep=sep))
     

In [None]:
#|hide
# Use a separate name and freshly built nested data: the `nested_dict` fixture
# stays a plain dict, and the `set` calls below cannot mutate its sub-dicts
# (the dict constructor only copies the top level).
config = NestedDict(unflatten_dict(flat_dict, sep='/'), sep='.')

# Reading by string path and by list path.
test_eq(config.at('wandb'), nested_dict['wandb'])
test_eq(config.at(['wandb']), nested_dict['wandb'])
test_eq(config.at('wandb.username'), 'bdsaglam')
test_eq(config.at(['train', 'lr']), nested_dict['train']['lr'])

# Missing paths fall back to the default.
test_eq(config.at('a.b.c'), None)
test_eq(config.at('train.non-existing-field'), None)
test_eq(config.at('train.non-existing-field', 0), 0)

# Overwriting existing values and creating new nested levels.
config.set('dataset_path', '/newpath')
test_eq(config.at('dataset_path'), '/newpath')
config.set('train.lr', 1)
test_eq(config.at('train.lr'), 1)
config.set('train.optimizer.name', 'adam')
config.set('train.optimizer.momentum', 0.9)
test_eq(config.at('train.optimizer.name'), 'adam')
test_eq(config.at('train.optimizer.momentum'), 0.9)

In [None]:
#|hide
# Building from '/'-joined keys must yield the corresponding nested structure.
flat_input = {'a/b/c': 1, 'd': 2}
expected_nested = {'a': {'b': {'c': 1}}, 'd': 2}
nested_dict = NestedDict.from_flat_dict(flat_input, sep='/')
test_eq(dict(nested_dict), expected_nested)

In [None]:
#|hide
# Run nbdev's export step (presumably writes the `#|export` cells to the
# module named by `#|default_exp` — confirm against the nbdev docs).
import nbdev; nbdev.nbdev_export()