In [22]:
#export
from collections import OrderedDict
import math
import re

import numpy as np


# Public API of this module; keep in sync with the `#export` cells below.
__all__ = ['default', 'merge_dicts', 'to_snake_case', 'pairs', 'classname',
           'to_list', 'autoformat', 'is_scalar', 'broadcast', 'unwrap_if_single',
           'from_torch', 'make_axis_if_needed', 'calculate_layout']

In [None]:
#export
def default(x, fallback=None):
    """Returns `x`, or `fallback` when `x` is None.

    Only `None` triggers the fallback: other falsy values such as 0, '' or []
    are returned unchanged.
    """
    if x is None:
        return fallback
    return x

In [None]:
#export
def merge_dicts(ds):
    """Merges a list of dictionaries into a single ordered dictionary.

    The dictionaries are applied in list order, so when a key appears in
    several of them the value from the last one wins. The key keeps the
    position at which it was first seen.
    """
    return OrderedDict(
        (key, value) for mapping in ds for key, value in mapping.items())

In [None]:
#export
def to_snake_case(string):
    """Converts a CamelCase / mixedCase identifier into snake_case.

    Examples: 'CamelCase' -> 'camel_case', 'myVar2Name' -> 'my_var2_name'.
    """
    # Pass 1 inserts '_' before capitalised words ("FooBar" -> "Foo_Bar").
    # Pass 2 handles remaining lowercase/digit -> uppercase boundaries.
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string)
    boundaries_split = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return boundaries_split.lower()

In [None]:
#export
def pairs(seq):
    """Yields adjacent pairs of an iterable: [1, 2, 3] -> (1, 2), (2, 3).

    This is a generator, not a list. It accepts any iterable, including
    iterators and generators that cannot be sliced. Inputs with fewer than
    two elements yield nothing.
    """
    items = iter(seq)
    try:
        previous = next(items)
    except StopIteration:
        # Empty input: no pairs at all.
        return
    for current in items:
        yield previous, current
        previous = current

In [None]:
#export
def classname(x):
    """Returns the name of the class of `x`, e.g. ``classname(1) == 'int'``."""
    cls = x.__class__
    return cls.__name__

In [17]:
#export
def to_list(obj):
    """Converts an iterable into a list, or wraps a scalar value in a list.

    - Strings are treated as scalars: ``to_list('ab') == ['ab']``.
    - Dicts yield their keys.
    - Generators and iterators are consumed.
    - Zero-dimensional array-likes (anything with ``ndim == 0``, such as a
      0-d NumPy array or a NumPy scalar) are wrapped rather than iterated,
      because ``list()`` cannot iterate them.
    """
    if isinstance(obj, str):
        return [obj]
    # 0-d arrays define __len__ but raise on iteration, so check before the
    # generic iterable branch below.
    if getattr(obj, 'ndim', None) == 0:
        return [obj]
    if (hasattr(obj, '__len__') or hasattr(obj, '__next__')
            or hasattr(obj, '__iter__')):
        return list(obj)
    return [obj]

In [18]:
def gen(x):
    """Small generator used to check that `to_list` drains iterators."""
    yield from range(x)


assert to_list([1, 2, 3]) == [1, 2, 3]
assert to_list((1, 2, 3)) == [1, 2, 3]
assert to_list({'a': 1, 'b': 2, 'c': 3}) == ['a', 'b', 'c']
assert to_list(1) == [1]
assert to_list(gen(5)) == list(range(5))
assert to_list('string') == ['string']
assert to_list(range(3)) == [0, 1, 2]
assert to_list(None) == [None]

In [19]:
#export
def autoformat(v):
    """Tries to convert a value into a string using the best possible representation.

    - Booleans (Python or NumPy) are rendered as 'True'/'False', not '1'/'0'.
    - Integers (Python or any NumPy integer dtype) are rendered without
      decimals.
    - Floats (Python or any NumPy floating dtype) are rendered with four
      decimal places.
    - Everything else goes through `str`.
    """
    # bool is a subclass of int, so it must be handled before the int branch.
    if isinstance(v, (bool, np.bool_)):
        return str(v)
    if isinstance(v, (int, np.integer)):
        return f'{v:d}'
    if isinstance(v, (float, np.floating)):
        return f'{v:.4f}'
    return str(v)

In [25]:
assert autoformat(1) == '1'
assert autoformat(1.11111) == '1.1111'
assert autoformat('string') == 'string'
assert autoformat(np.float16(1)) == '1.0000'
assert autoformat(np.int64(7)) == '7'
assert autoformat(np.float32(0.5)) == '0.5000'

In [28]:
#export
def is_scalar(obj):
    """Checks whether `obj` is a single scalar value rather than a collection.

    These count as scalars:
    - Python ints (including bools), floats, complex numbers and strings.
    - NumPy scalar numbers (`np.int64`, `np.float32`, ...) and NumPy booleans.

    Lists, tuples, dicts, NumPy arrays and None do not.
    """
    return isinstance(obj, (int, float, complex, str, np.number, np.bool_))

In [29]:
assert all(is_scalar(x) for x in (1, 1., 1j+0, 'string'))
assert not any(is_scalar(x) for x in ([1], (1,), {'a': 1}, None))

In [30]:
#export
def broadcast(obj, pad=1):
    """Repeats a scalar `pad` times in a list; returns non-scalars unchanged.

    Scalars are whatever `is_scalar` accepts. For example,
    ``broadcast(1, pad=3) == [1, 1, 1]`` while ``broadcast([1, 2]) == [1, 2]``.
    """
    return [obj] * pad if is_scalar(obj) else obj

In [31]:
assert broadcast([1, 2, 3]) == [1, 2, 3]
assert broadcast(1) == [1]
assert broadcast(1, pad=3) == [1, 1, 1]
assert broadcast('ab', pad=2) == ['ab', 'ab']

In [None]:
#export
def unwrap_if_single(obj):
    """Returns the only element of `obj` if it has exactly one; otherwise `obj` itself.

    `obj` must support `len()` and, when it has length 1, indexing with 0.
    """
    if len(obj) != 1:
        return obj
    return obj[0]

In [None]:
assert unwrap_if_single([1]) == 1
assert unwrap_if_single([1, 2]) == [1, 2]
assert unwrap_if_single([]) == []

In [4]:
#export
def from_torch(tensor):
    """Converts a torch tensor into a NumPy array, or into a Python scalar if it is 0-d.

    The tensor is detached from the autograd graph and moved to CPU first, so
    tensors that require grad or live on another device are accepted.
    """
    cpu_tensor = tensor.detach().cpu()
    # An empty shape means a 0-d tensor, which becomes a plain Python number.
    if cpu_tensor.shape:
        return cpu_tensor.numpy()
    return cpu_tensor.item()

In [7]:
import torch
import numpy as np

assert np.allclose(from_torch(torch.tensor([1, 2, 3])), np.array([1, 2, 3]))
# 0-d tensors come back as plain Python scalars.
assert from_torch(torch.tensor(3.5)) == 3.5
assert isinstance(from_torch(torch.tensor(3.5)), float)
# Tensors that require grad must still convert (they are detached first).
assert np.allclose(from_torch(torch.ones(2, requires_grad=True)), np.ones(2))

In [1]:
#export
def make_axis_if_needed(ax=None, **params):
    """Returns `ax` unchanged, or creates a new single matplotlib axis when `ax` is None.

    Any extra keyword arguments are forwarded to ``plt.subplots(1, 1, ...)``
    and are only used when a new axis is created.

    NOTE(review): relies on a module-level `plt` (presumably
    `matplotlib.pyplot`) that is not imported in the visible cells; confirm
    it is imported elsewhere, otherwise the `ax is None` path raises NameError.
    """
    if ax is not None:
        return ax
    _fig, new_ax = plt.subplots(1, 1, **params)
    return new_ax

In [None]:
#export
def calculate_layout(num_axes, n_rows=None, n_cols=None):
    """Calculates the grid ``(n_rows, n_cols)`` needed to fit `num_axes` plots.

    At most one of `n_rows` / `n_cols` may be given. The other dimension is
    derived so that the grid has at least `num_axes` cells and at least one
    row/column. If neither is given, two columns are assumed.

    Args:
        num_axes: Number of plots to place on the figure.
        n_rows: Fixed number of rows, or None to derive it.
        n_cols: Fixed number of columns, or None to derive it.

    Returns:
        A tuple ``(n_rows, n_cols)``.

    Raises:
        ValueError: If both `n_rows` and `n_cols` are given, or if the given
            value is smaller than 1. A zero would otherwise cause a division
            by zero, and a negative value would produce a nonsensical layout.
    """
    if n_rows is not None and n_cols is not None:
        raise ValueError(
            'cannot derive number of rows/columns if both values provided')
    if n_rows is None and n_cols is None:
        n_cols = 2
    if n_rows is None:
        if n_cols < 1:
            raise ValueError(f'n_cols must be at least 1, got {n_cols}')
        n_rows = max(1, math.ceil(num_axes / n_cols))
    else:
        if n_rows < 1:
            raise ValueError(f'n_rows must be at least 1, got {n_rows}')
        n_cols = max(1, math.ceil(num_axes / n_rows))
    return n_rows, n_cols