In [None]:
#| default_exp utils

# General utilities

In [None]:
#| hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#| export

import os
from collections.abc import Mapping
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Union, List

import pandas as pd
from fastcore.basics import patch

In [None]:
#| export

@patch
def ls_sorted(self:Path):
    "Like `ls`, but ordered by the integer value of each file name (final suffix removed)"
    def _numeric_stem(f):
        # e.g. '12.png' -> 12; a non-numeric name makes `int` raise ValueError
        return int(f.stem)
    return self.ls().sorted(key=_numeric_stem)

In [None]:
#| export

def flatten_dict(d: Dict, sep='.') -> Dict:
    """Flattens a nested dict into a single-level dict whose keys are `sep`-joined paths.

    Example: `flatten_dict({'a': {'b': 1}, 'c': 2})` returns `{'a.b': 1, 'c': 2}`.

    Only `dict` values are descended into; every other value (lists included)
    is kept as a leaf and returned unchanged, with its original type.
    Top-level keys are kept as they are, nested keys are formatted into the
    joined path string. Nested dicts that are empty contribute no keys, so an
    empty input returns an empty dict.
    """
    flat = {}

    def _walk(node, prefix):
        for key, value in node.items():
            # `prefix is None` marks the top level, where the key is used as-is
            full_key = key if prefix is None else f"{prefix}{sep}{key}"
            if isinstance(value, dict):
                _walk(value, full_key)
            else:
                flat[full_key] = value

    _walk(d, None)
    return flat

def make_nested_dict(subkeys: List, value):
    """Builds a chain of single-key dicts along `subkeys` with `value` at the innermost level.

    Example: `make_nested_dict(['a', 'b', 'c'], 1)` returns `{'a': {'b': {'c': 1}}}`.
    When `subkeys` is empty there is nothing to nest under, so `value` itself
    is returned.
    """
    nested = value
    # Wrap from the innermost key outwards so the first subkey ends up on top
    for key in reversed(subkeys):
        nested = {key: nested}
    return nested

def unflatten_dict(d: Dict, sep='.') -> Dict:
    """Rebuilds a nested dict from a flat dict whose keys are `sep`-joined paths.

    Each key is split on `sep`: every part but the last names a nested dict to
    descend into (created when absent), and the last part receives the value.
    """
    nested = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        cursor = nested
        for name in parents:
            # Create the intermediate level the first time it is seen
            if name not in cursor:
                cursor[name] = {}
            cursor = cursor[name]
        cursor[leaf] = value
    return nested


In [None]:
#| hide

# Fixture: a config-like dict nested up to three levels deep
nested_dict = {
    'dataset_path': 'a/b/c/d',
    'train': {
        'lr': 1e-4,
        'n_epoch': 10,
        'early_stop': {
            'patience': 10,
            'metric': 'val_loss',
        }
    },
    'wandb': {
        'username': 'bdsaglam',
        'project': 'project-x',
    }
}

# The same content as a single-level dict, with key paths joined by '/'
flat_dict = {
    'dataset_path': 'a/b/c/d',
    'train/lr': 0.0001,
    'train/n_epoch': 10,
    'train/early_stop/patience': 10,
    'train/early_stop/metric': 'val_loss',
    'wandb/username': 'bdsaglam',
    'wandb/project': 'project-x',
}

# Check both directions of the conversion against the fixtures
test_eq(flatten_dict(nested_dict, sep='/'), flat_dict)
test_eq(unflatten_dict(flat_dict, sep='/'), nested_dict)

In [None]:
#| export
from collections import Counter

def most_common(lst):
    """Returns the element that occurs most often in `lst`.

    Counting and tie-breaking are delegated to `Counter.most_common`.
    An empty `lst` raises `IndexError`.
    """
    ranked = Counter(lst).most_common(1)
    winner, _occurrences = ranked[0]
    return winner

In [None]:
#| hide
# Clear winner: 3 appears four times
test_eq(most_common([1,1,1,2,2,3,3,3,3,4,4]), 3)
# Tie between 1 and 3 (three occurrences each): the expected result is 1
test_eq(most_common([1,1,1,2,2,3,3,3,4,4]), 1)
# Single-element collection
test_eq(most_common([0]), 0)

In [None]:
#| export

# ref: https://dev.to/teckert/changing-directory-with-a-python-context-manager-2bj8
@contextmanager
def set_dir(path: Union[Path, str]):
    """Temporarily switches the current working directory to `path`.

    The directory that was current on entry is restored when the context
    exits, including when the body raises.
    """
    previous_dir = os.getcwd()
    try:
        os.chdir(path)
        yield
    finally:
        # Always go back to where we started, even on error
        os.chdir(previous_dir)

In [None]:
#| export
from datetime import datetime

def generate_time_id(dt=None):
    """Builds a string id from `dt`, or from the current time when `dt` is not given.

    The id is the ISO-format text of the datetime with everything after the
    last '.' removed and every ':' replaced by '-',
    e.g. `2022-01-01T01-01-01`.
    """
    moment = dt if dt else datetime.now()
    iso_text = moment.isoformat()
    # Keep only what precedes the last '.' (the fractional seconds, when present)
    head = iso_text.rsplit('.', 1)[0]
    return head.replace(':', '-')

In [None]:
#| hide

# Explicit datetime: the colons of the time part become dashes
test_eq(generate_time_id(datetime(2022, 1, 1, 1, 1, 1)), '2022-01-01T01-01-01')

# No argument (current time): same 19-character shape with four dashes
time_id = generate_time_id()
test_eq(len(time_id), 19)
test_eq(time_id.count('-'), 4)

In [None]:
#| export

def get_node(
    tree: Dict, # tree to traverse
    path: str, # path of node
    sep: str = '.' # separator used in path
): # the node
    """returns the node from a tree (dict) by path

    An empty or `None` path returns `tree` itself. `None` is returned when a
    field of the path is missing, or when the path tries to descend into a
    value that is not a mapping (a leaf such as an int, a str or a list).
    """
    if path is None or path == '':
        return tree
    node = tree
    for field in path.split(sep):
        # Only mappings can be descended into: on a scalar `in` would raise,
        # and on a str it would do substring matching instead of a key lookup.
        if not isinstance(node, Mapping) or field not in node:
            return None
        node = node[field]
    return node


In [None]:
#| hide
# Fixture: leaves at depth one ('a', 'f'), two ('b.e') and three ('b.c.d')
tree = {
    'a': 1,
    'b': {
        'c': {
            'd': 2,
        },
        'e': 3,
    },
    'f': 4,
}

test_eq(get_node(tree, 'a'), 1)
test_eq(get_node(tree, 'b.c.d'), 2)
# A missing field yields None
test_eq(get_node(tree, 'b.c.z'), None)
test_eq(get_node(tree, 'b.e'), 3)
# Custom separator
test_eq(get_node(tree, 'b/e', sep='/'), 3)
# A None path returns the tree itself
test_eq(get_node(tree, None), tree)

In [None]:
#| export
def apply_nested(tree: dict, path: str, func, sep: str = '.'):
    """Replaces the value at `path` in `tree` with `func(value)` and returns `tree`.

    `tree` is modified in place. `path` is split on `sep`; all parts but the
    last locate the parent node via `get_node`, and the last part is the key
    whose value is transformed.

    Raises `KeyError` when the last key is missing from its parent, and
    `TypeError` when the parent itself cannot be found (`get_node` returns
    `None`, which cannot be indexed).
    """
    *parent_parts, leaf = path.split(sep)
    # Forward `sep` so the parent lookup splits on the same separator as `path`
    parent_node = get_node(tree, sep.join(parent_parts), sep)
    parent_node[leaf] = func(parent_node[leaf])
    return tree

In [None]:
#| hide
# Fixture for apply_nested, which updates the tree in place
tree = {
    'a': 5,
    'b': {
        'c': {
            'd': 2,
        },
        'e': 3,
    },
    'f': 4,
}

func = lambda x: x*x

# Square three of the four leaves; 'b.e' is deliberately left out
for path in ['a', 'b.c.d', 'f']:
    apply_nested(tree, path, func)

test_eq(tree['a'], 25)
test_eq(tree['b']['c']['d'], 4)
# Untouched leaf keeps its original value
test_eq(tree['b']['e'], 3)
test_eq(tree['f'], 16)

In [None]:
#| export

def resolve_path(config, field_path, sep='.'):
    """Resolves the path stored at `field_path` in `config` and returns `config`.

    The value found at `field_path` is passed through `Path(...).resolve()` and
    stored back as a string, using `apply_nested` to perform the update.
    """
    def _to_resolved_str(raw):
        return str(Path(raw).resolve())

    return apply_nested(config, field_path, _to_resolved_str, sep)

In [None]:
#| hide
# Fixture: relative paths at the top level and one level down
config = {
    'data_path': './a/b/c',
    'model': {
        'save_path': './path/to/artifact'
    }
}

# resolve_path updates `config` through apply_nested, so the return value is not needed
resolve_path(config, 'data_path')
resolve_path(config, 'model.save_path')

# After resolution neither stored path starts with '.'
assert not config['data_path'].startswith('.')
assert not config['model']['save_path'].startswith('.')

In [None]:
#| hide
import tempfile

# Use a throwaway directory instead of a hard-coded system path so the check
# does not depend on a particular directory existing on the machine.
origin = os.getcwd()
with tempfile.TemporaryDirectory() as tmp_dir:
    with set_dir(tmp_dir):
        # Compare resolved paths: the temp root may be reached through a symlink
        test_eq(os.path.realpath(os.getcwd()), os.path.realpath(tmp_dir))
    # The previous cwd is restored before the temp dir is cleaned up
    test_eq(os.getcwd(), origin)
test_eq(os.getcwd(), origin)

In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the module named by `default_exp` - confirm against project settings)
import nbdev; nbdev.nbdev_export()