In [None]:
#default_exp hdict

In [None]:
#export
from box import Box

### Helpers that treat any nested dict as a hierarchical dict addressed by '.'-joined keys
def hdict_get(dic, hkey, default=None):
    """Return the value at hierarchical key ``hkey`` in the nested dict ``dic``.

    ``hkey`` is either a dot-separated string (``'a.b.c'``) or a sequence of
    path components (``['a', 'b', 'c']``).

    Works like ``dict.setdefault`` applied level by level, so it MUTATES
    ``dic`` when the path is absent: each missing intermediate level is
    created as an empty ``type(dic)()`` and a missing final key is stored
    with value ``default``.

    Returns the existing value at the final key, or ``default`` after
    storing it.
    """
    path = hkey.split('.') if isinstance(hkey, str) else list(hkey)
    node = dic
    for key in path[:-1]:
        # A fresh container per level, so new sibling levels never share one object.
        node = node.setdefault(key, type(dic)())
    return node.setdefault(path[-1], default)
def hdict_set(dic, hkey, val):
    """Store ``val`` at hierarchical key ``hkey`` in the nested dict ``dic``.

    ``hkey`` is a dot-separated string or a sequence of path components.
    Missing intermediate levels are created as empty ``type(dic)()``
    instances; existing intermediate levels are reused.  The final key is
    overwritten unconditionally.  Modifies ``dic`` in place and returns None.
    """
    path = hkey.split('.') if isinstance(hkey, str) else list(hkey)
    node = dic
    for key in path[:-1]:
        node = node.setdefault(key, type(dic)())
    node[path[-1]] = val
def hdict_keys(dic):
    """Return every leaf path of ``dic`` as a dot-joined string.

    Keys of ``dic`` whose values are not dicts come first, in insertion
    order, followed by the leaves of each nested dict (depth-first),
    prefixed with the parent key.  An empty nested dict contributes nothing.
    """
    paths, subtrees = [], []
    for key in dic.keys():
        if isinstance(dic[key], dict):
            subtrees.append(key)
        else:
            paths.append(f'{key}')
    for parent in subtrees:
        paths.extend(f'{parent}.{sub}' for sub in hdict_keys(dic[parent]))
    return paths
def hdict_flatten(dic):
    """Return a single-level ``type(dic)`` mapping leaf paths to their values.

    Keys are the dot-joined paths of every non-dict value, in the same order
    as ``hdict_keys``: a node's direct leaves first, then each nested dict
    depth-first.  Empty nested dicts are dropped.  Values are collected in a
    single walk rather than re-looked up by path, so non-string keys (which
    appear stringified in the path) keep their values and ``dic`` is never
    modified.
    """
    pairs = []

    def _walk(node, prefix):
        # Emit this node's leaves before descending, to match hdict_keys order.
        nested = []
        for key in node.keys():
            value = node[key]
            if isinstance(value, dict):
                nested.append((key, value))
            else:
                pairs.append((f'{prefix}{key}', value))
        for key, value in nested:
            _walk(value, f'{prefix}{key}.')

    _walk(dic, '')
    return type(dic)(pairs)
def hdict_fromflat(dic):
    """Rebuild a nested dict from a flat mapping of hierarchical keys.

    Each key of ``dic`` (e.g. ``'a.b.c'``, as produced by ``hdict_flatten``)
    is written into a new, initially empty ``type(dic)`` via ``hdict_set``,
    in iteration order.  ``dic`` itself is left untouched.
    """
    nested = type(dic)()
    for flat_key, value in dic.items():
        hdict_set(nested, flat_key, value)
    return nested
def hdict_subset(dic, hkeys, complement=False):
    """Select (or, with ``complement``, drop) the subtrees under ``hkeys``.

    ``hkeys`` is one hierarchical key or a list of them.  A flattened key is
    selected when it equals one of ``hkeys`` or lies below it, so matching
    respects '.' boundaries: ``'a.b'`` selects ``'a.b'`` and ``'a.b.c'`` but
    not ``'a.bc'``.  A prefix that is empty or already ends with '.' is
    matched as a plain string prefix.

    Returns a FLAT ``type(dic)`` keyed by dot-joined paths (built from
    ``hdict_flatten``), not a nested dict.
    """
    if isinstance(hkeys, str):
        hkeys = [hkeys]

    def _under(key, prefix):
        # True when `key` is `prefix` itself or a path below it.
        if not prefix or prefix.endswith('.'):
            return key.startswith(prefix)
        return key == prefix or key.startswith(prefix + '.')

    flat = hdict_flatten(dic)
    out = type(dic)()
    for key in flat.keys():
        selected = any(_under(key, prefix) for prefix in hkeys)
        # Keep selected keys normally, unselected ones when complementing.
        if selected != bool(complement):
            out[key] = flat[key]
    return out
def hdict_leaves(dic, hkey):
    """Return the non-dict children of the node at ``hkey``.

    ``hkey`` is a dot-separated string or a sequence of path components; an
    empty sequence addresses ``dic`` itself.  The result is a new
    ``type(dic)`` holding only the direct children whose values are not
    dicts.  If the path is missing, passes through a non-dict value, or ends
    on a non-dict value, an empty ``type(dic)()`` is returned.

    Unlike ``hdict_get``, this never modifies ``dic``.
    """
    path = hkey.split('.') if isinstance(hkey, str) else list(hkey)
    node = dic
    for key in path:
        if not isinstance(node, dict) or key not in node:
            return type(dic)()
        node = node[key]
    if not isinstance(node, dict):
        return type(dic)()
    return type(dic)({k: v for k, v in node.items() if not isinstance(v, dict)})
def hdict_overridden(dic, hkey):
    """Merge the leaf values found along the path ``hkey``, deeper levels winning.

    Starts from the non-dict children of ``dic`` itself, then merges in the
    non-dict children of each successive node on the path, so a value
    defined deeper overrides one with the same key defined higher up.  The
    walk stops at the first path component that is missing or whose value
    is not a dict.  Returns a new ``type(dic)``; ``dic`` is not modified.
    """
    path = hkey.split('.') if isinstance(hkey, str) else list(hkey)

    def _leaves(node):
        # Direct children of `node` that are not themselves sub-dicts.
        return {k: v for k, v in node.items() if not isinstance(v, dict)}

    merged = type(dic)(_leaves(dic))
    node = dic
    for key in path:
        if key not in node:
            break
        node = node[key]
        if not isinstance(node, dict):
            break
        merged.update(_leaves(node))
    return merged
def ptoml(b: Box) -> None:
    """Print ``b`` serialised as TOML, using its ``to_toml()`` method."""
    toml_text = b.to_toml()
    print(toml_text)
    
class HBox(Box):
    """Box whose item access understands hierarchical keys.

    A hierarchical key (keys concatenated with '.') corresponds to a
    hierarchical (dictionary) structure, e.g. ``d['a.b.c'] = 1`` addresses
    ``{'a': {'b': {'c': 1}}}``.  A tuple/list of path components such as
    ``('a', 'b', 'c')`` addresses the same location.  Any non-iterable key
    (e.g. an int) is treated as a single path component.  Because '.' is
    the separator, a literal key containing '.' cannot be reached through
    item access.
    """
    # Box's own single-level accessors, called unbound so that each step of
    # a path walk bypasses the hierarchical overrides defined below.
    _g = Box.__getitem__
    _s = Box.__setitem__
    _c = Box.__contains__

    def __repr__(self) -> str:
        return f'<{self.__class__.__name__}: {self.to_dict()}>'

    def _tohkey(self, hkey):
        """Normalise ``hkey`` to a tuple of path components."""
        if isinstance(hkey, str):
            return tuple(hkey.split('.'))
        try:
            return tuple(hkey)
        except TypeError:
            # Non-iterable keys (ints, None, ...) form a one-element path.
            return (hkey,)

    def __getitem__(self, hkey, _ignore_default=False):
        """Return the value at ``hkey``, descending one level per component.

        A missing level propagates whatever ``Box.__getitem__`` raises for a
        missing key (in python-box, a KeyError subclass).
        """
        cur = self
        for k in self._tohkey(hkey):
            cur = HBox._g(cur, k, _ignore_default)
        return cur

    def __contains__(self, hkey):
        """Return True if every component of ``hkey`` exists along the path.

        An empty path is always contained (it addresses ``self``).
        """
        cur = self
        for k in self._tohkey(hkey):
            # A non-mapping in the middle of the path means the path cannot exist.
            if not isinstance(cur, dict) or not HBox._c(cur, k):
                return False
            cur = HBox._g(cur, k)
        return True

    def __setitem__(self, hkey, val):
        """Store ``val`` at ``hkey``, creating missing intermediate levels.

        Raises TypeError if an existing intermediate value is not a mapping.
        """
        hkey = self._tohkey(hkey)
        cur = self
        for k in hkey[:-1]:
            try:
                cur = HBox._g(cur, k)
            except KeyError:
                # Missing level: create it, then re-read so `cur` is the object
                # Box actually stored (Box may convert the plain dict).
                HBox._s(cur, k, {})
                cur = HBox._g(cur, k)
            if not isinstance(cur, dict):
                raise TypeError(
                    f"cannot set {'.'.join(map(str, hkey))!r}: "
                    f"{k!r} holds a non-mapping value")
        HBox._s(cur, hkey[-1], val)

    def hkeys(self):
        """Return every leaf path as a dot-joined string.

        Direct non-dict children come first, then the leaves of each nested
        mapping depth-first; empty nested mappings contribute nothing.
        """
        def _hkeys(dic):
            children = [k for k in dic.keys() if isinstance(dic[k], dict)]
            lower = [f'{y}.{x}' for y in children for x in _hkeys(dic[y])]
            current = [f'{y}' for y in dic.keys() if y not in children]
            return current + lower
        return _hkeys(self)

    def flatten(self):
        """Return a plain, single-level ``Box`` mapping leaf paths to values.

        Keys are the dot-joined paths in ``hkeys()`` order; empty nested
        mappings are dropped.  Values are gathered in one walk instead of
        being re-looked up by path, so non-string keys work.  They appear
        stringified in the path and therefore do not round-trip through
        ``fromflat``.
        """
        pairs = []

        def _walk(node, prefix):
            # Emit this node's leaves before descending, to match hkeys() order.
            nested = []
            for k in node.keys():
                v = node[k]
                if isinstance(v, dict):
                    nested.append((k, v))
                else:
                    pairs.append((f'{prefix}{k}', v))
            for k, v in nested:
                _walk(v, f'{prefix}{k}.')

        _walk(self, '')
        return Box(pairs)

    @classmethod
    def fromflat(cls, dic):
        """Build a new instance by assigning each (hierarchical key, value) of ``dic``."""
        d = cls()
        for k, v in dic.items():
            d[k] = v
        return d

    def subset(self, hkeys, complement=False):
        """Return a new instance with (or, with ``complement``, without) the subtrees under ``hkeys``.

        ``hkeys`` is one hierarchical key or a list of them.  Matching
        respects '.' boundaries: ``'a.b'`` selects ``'a.b'`` and ``'a.b.c'``
        but not ``'a.bc'``.  A prefix that is empty or already ends with '.'
        is matched as a plain string prefix.
        """
        if isinstance(hkeys, str):
            hkeys = [hkeys]

        def _under(key, prefix):
            # True when `key` is `prefix` itself or a path below it.
            if not prefix or prefix.endswith('.'):
                return key.startswith(prefix)
            return key == prefix or key.startswith(prefix + '.')

        fd = self.flatten()
        d = self.__class__()
        for k in fd.keys():
            selected = any(_under(k, p) for p in hkeys)
            # Keep selected keys normally, unselected ones when complementing.
            if selected != bool(complement):
                d[k] = fd[k]
        return d

    def leaves(self, hkey):
        """Return a ``Box`` of the non-mapping children of the node at ``hkey``.

        Returns an empty ``Box`` if that node is not a mapping.  A missing
        path raises, as in ``__getitem__``.
        """
        tgt = self[hkey]
        if isinstance(tgt, dict):
            return Box({k: v for k, v in tgt.items() if not isinstance(v, dict)})
        return Box()

    def overridden(self, hkey):
        """Merge the leaves found along ``hkey``, deeper levels overriding shallower ones.

        Starts with the leaves of ``self`` (the empty path) and merges in the
        leaves of each existing prefix of ``hkey``; missing prefixes are
        skipped.
        """
        hkey = self._tohkey(hkey)
        leaves = [self.leaves(hkey[:i]) for i in range(len(hkey) + 1) if hkey[:i] in self]
        cur = leaves[0]
        for lv in leaves[1:]:
            cur.update(lv)
        return cur

    def copy(self):
        """Return a copy of this box, re-wrapped in this instance's class."""
        return self.__class__(Box.copy(self))

    def merge(self, other):
        """Return a new instance combining ``self`` and ``other`` leaf by leaf.

        Both sides are flattened; on a clashing leaf path ``other`` wins.
        Empty nested mappings are dropped.  The result has ``type(self)``,
        and neither input is modified.
        """
        cls = type(self)
        if not isinstance(other, HBox):
            other = cls(other)
        merged = self.flatten()
        merged.update(other.flatten())
        return cls.fromflat(merged)