In [None]:
# default_exp utils

# Utils

> Collection of useful functions.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import os
import numpy as np

from typing import Iterable, TypeVar, Generator
from plum import dispatch
from pathlib import Path
from functools import reduce

# Concrete type of plain Python functions and lambdas, obtained by taking the
# type of a throwaway lambda.
function = type(lambda: ())
# Generic type variable used in annotations.
T = TypeVar('T')

## Basics

In [None]:
#export
def identity(x: T) -> T:
    """Identity function: return `x` unchanged."""

    return x

In [None]:
#export
def simplify(x):
    """Unwrap `x` when it is a "lonely" value, otherwise return it untouched.

    - A plain Python function (or lambda) is called without arguments and the
      result of that call is returned. If the call raises `TypeError`
      (typically because the function has required parameters) the function
      itself is returned.
    - An iterable holding exactly one element is replaced by that element.
      Iterables without a length (generators, iterators, ...) are returned
      unchanged instead of being consumed.
    - Anything else is returned as is.
    """
    # Function-scope import: the block only needs this one extra stdlib name.
    from types import FunctionType

    if isinstance(x, FunctionType):
        try:
            return x()
        except TypeError:
            # Best effort: the function could not be called without arguments.
            # Note that a TypeError raised inside the function body lands here too.
            return x

    if isinstance(x, Iterable):
        try:
            size = len(x)
        except TypeError:
            # No usable length: we cannot tell whether it is lonely without
            # consuming it, so hand it back untouched.
            return x
        return next(iter(x)) if size == 1 else x

    return x

The `simplify` function de-nests an iterable that holds a single element, for instance `[1]`, and leaves everything else unchanged. When it is given a function that can be called without arguments, it returns the result of that call instead of the function.

In [None]:
simplify({1})

1

In [None]:
simplify(simplify)(lambda x='lul': 2*x)

'lullul'

In [None]:
#export
def listify(x, *args):
    """Turn `x` into a `list`.

    Extra positional arguments are gathered with `x` into one sequence first.
    `None` becomes an empty list, an existing list is returned as is, strings
    and objects exposing `__array__` or `iloc` are wrapped as a single item,
    other iterables are unpacked, and anything else is wrapped in a list.
    """
    # Several positional values: treat them as the items themselves.
    if args:
        x = (x,) + args

    if x is None:
        return []
    if isinstance(x, list):
        return x

    # Strings and array-/frame-like objects are iterable but count as one item.
    atomic = isinstance(x, str) or hasattr(x, "__array__") or hasattr(x, "iloc")
    if not atomic and isinstance(x, (Iterable, Generator)):
        return list(x)

    return [x]

What's very convenient is that it leaves lists invariant (it doesn't nest them into a new list).

In [None]:
listify([1, 2])

[1, 2]

In [None]:
listify(1, 2, 3)

[1, 2, 3]

In [None]:
#export
def setify(x, *args):
    """Turn `x` (plus any extra positional values) into a `set`."""
    # Delegate the input normalisation to `listify`, then drop duplicates.
    items = listify(x, *args)
    return set(items)

In [None]:
setify(1, 2, 3)

{1, 2, 3}

In [None]:
#export
def tuplify(x, *args):
    """Turn `x` (plus any extra positional values) into a `tuple`."""
    # Delegate the input normalisation to `listify`, then freeze the result.
    items = listify(x, *args)
    return tuple(items)

In [None]:
tuplify(1)

(1,)

In [None]:
#export
def merge_tfms(*tfms):
    """Merge any number of dictionaries, pooling the values of shared keys.

    For every key, the values of all dictionaries are gathered, de-duplicated
    and stored as a list; a key left with a single value keeps that value
    directly instead of a one-element list. De-duplication goes through a
    `set`, so the order of the pooled values is not guaranteed.
    """

    def _combine(left, right):
        combined = {}
        for key in {**left, **right}:
            # A key missing on one side contributes nothing (listify(None) == []).
            pooled = listify(left.get(key)) + listify(right.get(key))
            unique = listify(setify(pooled))
            combined[key] = simplify(unique)
        return combined

    return reduce(_combine, tfms, dict())

In [None]:

merge_tfms(
    {'animals': ['cats', 'dog'], 'colors': 'blue'}, 
    {'animals': 'cats', 'colors': 'red', 'OS': 'i use arch btw'}
)

{'animals': ['cats', 'dog'], 'colors': ['red', 'blue'], 'OS': 'i use arch btw'}

In [None]:
#export
def compose(*functions):
    """Compose any number of functions, the rightmost one being applied first.

    With no function given, `identity` is returned.
    """

    def _chain(outer, inner):
        def _chained(x):
            return outer(inner(x))
        return _chained

    # Left fold starting from `identity`: compose(f, g)(x) == f(g(x)).
    composed = identity
    for fn in functions:
        composed = _chain(composed, fn)
    return composed

In [None]:
#export
def pipe(*functions):
    """Chain any number of functions, the leftmost one being applied first."""
    # A pipe is a composition written in reading order.
    return compose(*reversed(functions))

In [None]:
#export
def flow(data, *functions):
    """Send `data` through `functions`, applied from left to right."""
    pipeline = pipe(*functions)
    return pipeline(data)

## File manipulation helper

In [None]:
#export
def get_files(path, extensions=None, recurse=False, folders=None, followlinks=True):
    """List the non-hidden files found under `path`, as strings.

    Args:
        path: Directory to search.
        extensions: Extension or collection of extensions to keep, each with
            its leading dot (e.g. `".npy"`); compared case-insensitively.
            When empty or `None`, every file is kept.
        recurse: If true, walk the whole tree under `path`; otherwise only
            look at the files directly inside `path`.
        folders: Only used when `recurse` is true. Names of the direct
            sub-folders of `path` to descend into. When given, files sitting
            directly in `path` are skipped unless `"."` is listed.
        followlinks: Passed to `os.walk` when `recurse` is true.

    Returns:
        A list of file paths converted to `str`. Files whose name starts with
        a dot are never returned.
    """
    path = Path(path)
    folders = listify(folders)
    extensions = {e.lower() for e in setify(extensions)}

    def simple_getter(p, fs, extensions=None):
        p = Path(p)
        return [
            p / f
            for f in fs
            if not f.startswith(".")
            # splitext yields "" for a name without a dot, so a file called
            # "npy" is not mistaken for one carrying the ".npy" extension.
            and ((not extensions) or os.path.splitext(f)[1].lower() in extensions)
        ]

    if recurse:
        result = []
        for i, (p, d, f) in enumerate(os.walk(path, followlinks=followlinks)):
            # Assigning to d[:] prunes, in place, the directories os.walk will visit next.
            if len(folders) != 0 and i == 0:
                # Top level with a folder filter: only descend into the requested folders.
                d[:] = [o for o in d if o in folders]
            else:
                # Everywhere else: skip hidden directories.
                d[:] = [o for o in d if not o.startswith(".")]
            if len(folders) != 0 and i == 0 and "." not in folders:
                # Files directly in `path` are only collected when "." was requested.
                continue
            result += simple_getter(p, f, extensions)
    else:
        # The context manager closes the scandir iterator even if is_file() raises.
        with os.scandir(path) as entries:
            f = [o.name for o in entries if o.is_file()]
        result = simple_getter(path, f, extensions)
    return list(map(str, result))

In [None]:
def get_array(suffix):
    """Return a getter that loads the numpy array stored next to a file.

    The returned callable takes a path `x` (a `str` or a `Path`), replaces its
    extension with `<suffix>.npy` and loads that file with `np.load`.

    Args:
        suffix: Marker placed before `.npy` in the file name, e.g. `".mask"`.
            A missing leading dot is added; an empty suffix targets
            `<stem>.npy`.
    """
    # Path.with_suffix rejects suffixes that do not start with a dot.
    if suffix and not suffix.startswith("."):
        suffix = "." + suffix

    def _load(x):
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only use on trusted files.
        return np.load(Path(x).with_suffix(f'{suffix}.npy'), allow_pickle=True)

    return _load

In [None]:
def save_array(array, fname, suffix):
    """Save `array` to a `.npy` file named after `fname` and `suffix`.

    The extension of `fname` is replaced by `suffix`, and `np.save` appends
    `.npy` when the resulting name does not already end with it. For example
    `save_array(a, "img.png", "mask")` writes `img.mask.npy`.

    Args:
        array: Data to store.
        fname: Path (a `str` or a `Path`) the output name is derived from.
        suffix: Replacement extension; a missing leading dot is added.

    Returns:
        The result of `np.save`, i.e. `None`.
    """
    # Path.with_suffix rejects suffixes that do not start with a dot.
    if suffix and not suffix.startswith("."):
        suffix = "." + suffix

    target = Path(fname).with_suffix(suffix)

    # np.save expects the destination first and the data second.
    return np.save(target, array)

In [None]:
#export
def save_dataset(data):
    """Placeholder for saving a dataset; not implemented yet.

    Raises:
        NotImplementedError: Always. Raising (rather than returning a string)
            keeps callers from assuming that `data` was written somewhere.
    """
    raise NotImplementedError("save_dataset is not implemented yet")