In [None]:
#default_exp utils

In [None]:
#export
from local.test import *
from local.core import *
from local.imports import *

In [None]:
from local.notebook.showdoc import *

# Utility functions

> Utility functions used in the fastai library

## Basics

In [None]:
# export
def ifnone(a, b):
    "Return `a` unless it is None, in which case return `b`"
    if a is None: return b
    return a

Since `b if a is None else a` is such a common pattern, we wrap it in a function. However, be careful: because `ifnone` is an ordinary function call, Python evaluates *both* `a` and `b` before calling it (which it doesn't do when you write the `if` expression directly), so avoid passing an expensive or side-effecting expression as `b`.

In [None]:
test_eq(ifnone(None,1), 1)
test_eq(ifnone(2   ,1), 2)

In [None]:
#export
def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
    "Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`"
    # Class attributes, in insertion order: `fld_names` (default None), then `funcs`
    # (keyed by function name), then `flds` (keyword name -> default value).
    attrs = {}
    for f in fld_names: attrs[f] = None
    for f in L(funcs): attrs[f.__name__] = f
    for k,v in flds.items(): attrs[k] = v
    # Normalise `sup` to a tuple of base classes (empty tuple when no base is given)
    sup = ifnone(sup, ())
    if not isinstance(sup, tuple): sup=(sup,)

    def _init(self, *args, **kwargs):
        # Positional args are matched to the keys of `attrs` by position (insertion order above);
        # keyword args are set as instance attributes under their own names.
        for i,v in enumerate(args): setattr(self, list(attrs.keys())[i], v)
        for k,v in kwargs.items(): setattr(self,k,v)

    def _repr(self):
        # One "name: value" line per public attribute that is not a bound method.
        # Iterates over a `set`, so the order of the lines is unspecified.
        return '\n'.join(f'{o}: {getattr(self,o)}' for o in set(dir(self))
                         if not o.startswith('_') and not isinstance(getattr(self,o), types.MethodType))

    # The generated `__repr__` is only installed when there are no base classes
    if not sup: attrs['__repr__'] = _repr
    attrs['__init__'] = _init
    res = type(nm, sup, attrs)
    if doc is not None: res.__doc__ = doc
    return res

In [None]:
_t = get_class('_t', 'a', b=2)
t = _t()
test_eq(t.a, None)
test_eq(t.b, 2)
t = _t(1, b=3)
test_eq(t.a, 1)
test_eq(t.b, 3)
t = _t(1, 3)
test_eq(t.a, 1)
test_eq(t.b, 3)

Most often you'll want to call `mk_class`, since it adds the class to your module. See `mk_class` for more details and examples of use (which also apply to `get_class`).

In [None]:
#export
def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, **flds):
    "Create a class using `get_class` and add to the caller's module"
    # With no explicit `mod` mapping, use the local namespace of the *direct* caller's frame.
    # NOTE(review): this relies on `mk_class` being called directly (one frame up); at module
    # top level that namespace is the module's globals.
    if mod is None: mod = inspect.currentframe().f_back.f_locals
    res = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, **flds)
    # The new class is stored under `nm` in `mod`; nothing is returned
    mod[nm] = res

Any `kwargs` will be added as class attributes, and `sup` is an optional (tuple of) base classes.

In [None]:
mk_class('_t', a=1, sup=GetAttr)
t = _t()
test_eq(t.a, 1)
assert(isinstance(t,GetAttr))

A `__init__` is provided that sets attrs for any `kwargs`, and for any `args` (matching by position to fields), along with a `__repr__` which prints all attrs. The docstring is set to `doc`. You can pass `funcs` which will be added as attrs with the function names.

In [None]:
def foo(self): return 1
mk_class('_t', 'a', sup=GetAttr, doc='test doc', funcs=foo)

t = _t(3, b=2)
test_eq(t.a, 3)
test_eq(t.b, 2)
test_eq(t.foo(), 1)
test_eq(t.__doc__, 'test doc')
t

<__main__._t at 0x7f11362e2278>

In [None]:
#export
def wrap_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
    "Decorator: makes function a method of a new class `nm` passing parameters to `mk_class`"
    def _inner(f):
        # The class is created in the globals of the module where `f` was defined, with `f`
        # appended to any `funcs` passed in, so `f` becomes a method named `f.__name__`.
        mk_class(nm, *fld_names, sup=sup, doc=doc, funcs=L(funcs)+f, mod=f.__globals__, **flds)
        # The decorated function itself is returned unchanged
        return f
    return _inner

In [None]:
@wrap_class('_t', a=2)
def bar(self,x): return x+1

t = _t()
test_eq(t.a, 2)
test_eq(t.bar(3), 4)

In [None]:
show_doc(noop)

<h4 id="noop" class="doc_header"><code>noop</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/local/imports.py#L57" class="source_link" style="float:right">[source]</a></h4>

> <code>noop</code>(**`x`**=*`None`*, **\*`args`**, **\*\*`kwargs`**)

Do nothing

In [None]:
noop()
test_eq(noop(1),1)

In [None]:
show_doc(noops)

<h4 id="noops" class="doc_header"><code>noops</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/local/imports.py#L61" class="source_link" style="float:right">[source]</a></h4>

> <code>noops</code>(**`x`**=*`None`*, **\*`args`**, **\*\*`kwargs`**)

Do nothing (method)

In [None]:
mk_class('_t', foo=noops)
test_eq(_t().foo(1),1)

In [None]:
#export
def store_attr(self, nms):
    "Store params named in comma-separated `nms` from calling context into attrs in `self`"
    # Local variables of the direct caller (typically an `__init__`); must be called directly
    # from the function whose locals are to be stored.
    mod = inspect.currentframe().f_back.f_locals
    # Names are split on a comma followed by any number of spaces; whitespace *before* a comma
    # is not stripped. A name missing from the caller's locals raises `KeyError`.
    for n in re.split(', *', nms): setattr(self,n,mod[n])

In [None]:
class T:
    def __init__(self, a,b,c): store_attr(self, 'a,b, c')

t = T(1,c=2,b=3)
assert t.a==1 and t.b==3 and t.c==2

In [None]:
#export
def attrdict(o, *ks):
    "Build a dict mapping each name in `ks` to the attribute of that name on `o`"
    res = {}
    for k in ks: res[k] = getattr(o, k)
    return res

In [None]:
test_eq(attrdict(t,'b','c'), {'b':3, 'c':2})

In [None]:
#export
def properties(cls, *ps):
    "Replace each attr of `cls` named in `ps` with a `property` that wraps it as the getter"
    for nm in ps:
        getter = getattr(cls, nm)
        setattr(cls, nm, property(getter))

In [None]:
class T:
    def a(self): return 1
    def b(self): return 2
properties(T,'a')

test_eq(T().a,1)
test_eq(T().b(),2)

In [None]:
#export core
# First pass: put `_` before a capitalised word (an upper-case letter followed by lower-case letters)
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
# Second pass: put `_` between a lower-case letter or digit and a following upper-case letter
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    "Convert CamelCase `name` to snake_case"
    step1 = _camel_re1.sub(r'\1_\2', name)
    step2 = _camel_re2.sub(r'\1_\2', step1)
    return step2.lower()

test_eq(camel2snake('ClassAreCamel'), 'class_are_camel')

In [None]:
#export
def class2attr(self, cls_name):
    "Snake-case name of `self`'s class with a trailing `cls_name` removed (`cls_name` lower-cased if nothing is left)"
    stem = re.sub(rf'{cls_name}$', '', self.__class__.__name__)
    if not stem: stem = cls_name.lower()
    return camel2snake(stem)

## Collection functions

In [None]:
#export
def tuplify(o, use_list=False, match=None):
    "Convert `o` to a tuple by way of `L`, forwarding `use_list` and `match`"
    items = L(o, use_list=use_list, match=match)
    return tuple(items)

In [None]:
test_eq(tuplify(None),())
test_eq(tuplify([1,2,3]),(1,2,3))
test_eq(tuplify(1,match=[1,2,3]),(1,1,1))

In [None]:
#export
def detuplify(x):
    "Unwrap `x`: None if it is empty, its sole item if it has exactly one, otherwise `x` unchanged"
    n = len(x)
    if n == 0: return None
    if n == 1: return x[0]
    return x

In [None]:
test_eq(detuplify(()),None)
test_eq(detuplify([1]),1)
test_eq(detuplify([1,2]), [1,2])

In [None]:
#export
def replicate(item,match):
    "Tuple holding `item` (the same object, not copies) repeated `len(match)` times"
    return tuple(item for _ in range(len(match)))

In [None]:
t = [1,1]
test_eq(replicate([1,2], t),([1,2],[1,2]))
test_eq(replicate(1, t),(1,1))

In [None]:
#export
def uniqueify(x, sort=False, bidir=False, start=None):
    "Return the unique elements in `x`, optionally `sort`-ed, optionally return the reverse correspondence."
    res = L(x).unique()
    # Optional items to place in front of the unique values
    # NOTE(review): `start` is concatenated as-is, so duplicates between `start` and `x` are kept
    if start is not None: res = start+res
    if sort: res.sort()
    # With `bidir`, also return the value -> index mapping for `res`
    if bidir: return res, res.val2idx()
    return res

In [None]:
# test
test_eq(set(uniqueify([1,1,0,5,0,3])),{0,1,3,5})
test_eq(uniqueify([1,1,0,5,0,3], sort=True),[0,1,3,5])
v,o = uniqueify([1,1,0,5,0,3], bidir=True)
test_eq(v,[1,0,5,3])
test_eq(o,{1:0, 0: 1, 5: 2, 3: 3})
v,o = uniqueify([1,1,0,5,0,3], sort=True, bidir=True)
test_eq(v,[0,1,3,5])
test_eq(o,{0:0, 1: 1, 3: 2, 5: 3})

In [None]:
# export
def setify(o):
    "Return `o` unchanged if it is already a `set`, otherwise the set of the items in `L(o)`"
    if isinstance(o, set): return o
    return set(L(o))

In [None]:
# test
test_eq(setify(None),set())
test_eq(setify('abc'),{'abc'})
test_eq(setify([1,2,2]),{1,2})
test_eq(setify(range(0,3)),{0,1,2})
test_eq(setify({1,2}),{1,2})

In [None]:
#export
def is_listy(x):
    "`isinstance(x, (tuple,list,L,slice,Generator))`"
    return isinstance(x, (tuple,list,L,slice,Generator))

In [None]:
assert is_listy([1])
assert is_listy(L([1]))
assert is_listy(slice(2))
assert not is_listy(array([1]))

In [None]:
#export
def range_of(x):
    "List of every valid index into collection `x`, i.e. `0 .. len(x)-1`"
    return [*range(len(x))]

In [None]:
test_eq(range_of([1,1,1,1]), [0,1,2,3])

In [None]:
#export
def groupby(x, key):
    "Eagerly bucket the items of `x` into a dict keyed by `key(item)`; unlike `itertools.groupby`, `x` need not be sorted"
    groups = {}
    for item in x:
        k = key(item)
        if k not in groups: groups[k] = []
        groups[k].append(item)
    return groups

In [None]:
test_eq(groupby('aa ab bb'.split(), itemgetter(0)), {'a':['aa','ab'], 'b':['bb']})

In [None]:
#export
def merge(*ds):
    "Combine the dicts in `ds` into one new dict, skipping any that are None; later dicts win on duplicate keys"
    merged = {}
    for d in ds:
        if d is None: continue
        for k,v in d.items(): merged[k] = v
    return merged

In [None]:
test_eq(merge(), {})
test_eq(merge(dict(a=1,b=2)), dict(a=1,b=2))
test_eq(merge(dict(a=1,b=2), dict(b=3,c=4), None), dict(a=1, b=3, c=4))

In [None]:
#export
def shufflish(x, pct=0.04):
    "Roughly keep the order of `x`, but let each item drift by up to `pct` of `len(x)` positions"
    n = len(x)
    # Sort key: the original position plus a random offset in [0, n*pct) (the constant `n` shifts every key equally)
    def _jittered(pos): return pos+n*(1+random.random()*pct)
    order = sorted(range_of(x), key=_jittered)
    return L(x[i] for i in order)

In [None]:
l = list(range(100))
l2 = array(shufflish(l))
test_close(l2[:50 ].mean(), 25, eps=5)
test_close(l2[-50:].mean(), 75, eps=5)
test_ne(l,l2)

In [None]:
#export
class IterLen:
    "Mixin giving iteration to any class that defines `__len__` and `__getitem__`"
    def __iter__(self):
        idxs = range_of(self)
        return (self[idx] for idx in idxs)

In [None]:
#export
@docs
class ReindexCollection(GetAttr, IterLen):
    "Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
    # Attribute named by `_default` — presumably where `GetAttr` forwards unknown attribute
    # lookups (the example cell calls `t.count(0)` on the wrapper); TODO confirm
    _default='coll'
    def __init__(self, coll, idxs=None, cache=None):
        # Without `idxs`, the identity indexing `L.range(coll)` is used
        self.coll,self.idxs,self.cache = coll,ifnone(idxs,L.range(coll)),cache
        # `_get` is bound per instance so that each instance can get its own `lru_cache` wrapper
        def _get(self, i): return self.coll[i]
        self._get = types.MethodType(_get,self)
        if cache is not None: self._get = functools.lru_cache(maxsize=cache)(self._get)

    # Position `i` is translated through `idxs` before fetching from `coll`; the cache, if any,
    # is keyed on the translated index
    def __getitem__(self, i): return self._get(self.idxs[i])
    # NOTE(review): the length comes from `coll`, not `idxs`, so iterating after a `reindex`
    # with a different number of indices would not match `idxs` — confirm this is intended
    def __len__(self): return len(self.coll)
    def reindex(self, idxs): self.idxs = idxs
    # Shuffles `self.idxs` in place
    def shuffle(self): random.shuffle(self.idxs)
    # NOTE(review): `_get` is only wrapped by `lru_cache` when `cache` was given; without it
    # `_get` is a plain bound method and presumably has no `cache_clear` — confirm
    def cache_clear(self): self._get.cache_clear()

    # Method docstrings, consumed by the `@docs` decorator
    _docs = dict(reindex="Replace `self.idxs` with idxs",
                shuffle="Randomly shuffle indices",
                cache_clear="Clear LRU cache")

In [None]:
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
test_eq(list(t), range(sz))
test_eq(t[sz-1], sz-1)
test_eq(t._get.cache_info().hits, 1)
t.shuffle()
test_eq(t._get.cache_info().hits, 1)
test_ne(list(t), range(sz))
test_eq(set(t), set(range(sz)))
t.cache_clear()
test_eq(t._get.cache_info().hits, 0)
test_eq(t.count(0), 1)

In [None]:
#export
def _oper(op,a,b=None): return (lambda o:op(o,a)) if b is None else op(a,b)

def _mk_op(nm, mod=None):
    "Create an operator using `_oper` and add to the caller's module"
    # With no explicit `mod` mapping, write into the namespace of the direct caller's frame
    if mod is None: mod = inspect.currentframe().f_back.f_locals
    op = getattr(operator,nm)
    def _inner(a,b=None): return _oper(op, a,b)
    # Make the wrapper present itself under the operator's own name
    _inner.__name__ = _inner.__qualname__ = nm
    _inner.__doc__ = f'Same as `operator.{nm}`, or returns partial if 1 arg'
    mod[nm] = _inner

In [None]:
#export
_all_ = ['lt', 'gt', 'le', 'ge', 'eq', 'ne', 'add', 'sub', 'mul', 'truediv']

In [None]:
#export
# Define `lt`, `gt`, ... in this namespace (`_mk_op` writes into its caller's namespace); these are
# the names listed in `_all_` above. The loop variable `op` is left defined afterwards.
for op in 'lt gt le ge eq ne add sub mul truediv'.split(): _mk_op(op)

The following functions are provided matching the behavior of the equivalent versions in `operator`:

 - *lt gt le ge eq ne add sub mul truediv*

In [None]:
lt(3,5),gt(3,5)

(True, False)

However, they also have additional functionality: if you only pass one param, they return a partial function that passes that param as the second positional parameter.

In [None]:
lt(5)(3),gt(5)(3)

(True, False)

In [None]:
#export
class _InfMeta(type):
    @property
    def count(self): return itertools.count()
    @property
    def zeros(self): return itertools.cycle([0])
    @property
    def ones(self):  return itertools.cycle([1])
    @property
    def nones(self): return itertools.cycle([None])

In [None]:
#export
class Inf(metaclass=_InfMeta):
    "Infinite lists"

`Inf` defines the following properties:
    
- `count: itertools.count()`
- `zeros: itertools.cycle([0])`
- `ones : itertools.cycle([1])`
- `nones: itertools.cycle([None])`

In [None]:
test_eq([o for i,o in zip(range(5), Inf.count)],
        [0, 1, 2, 3, 4])

test_eq([o for i,o in zip(range(5), Inf.zeros)],
        [0, 0, 0, 0, 0])

In [None]:
#export
def true(*args, **kwargs):
    "Predicate: always `True`, whatever arguments it is given"
    return True

In [None]:
#export
def stop(e=StopIteration):
    "Raises exception `e` (by default `StopIteration`) even if in an expression"
    # `raise` is a statement; wrapping it in a function lets it be used inside a lambda or expression
    raise e

In [None]:
#export
def gen(func, seq, cond=true):
    "Lazily yield `func(o)` for `o` in `seq`, stopping at the first result for which `cond` is false (or when `func` raises `StopIteration`)"
    # `takewhile` ends the stream at the first failing item; it does not filter and continue
    return itertools.takewhile(cond, map(func,seq))

In [None]:
test_eq(gen(noop, Inf.count, lt(5)),
        range(5))
test_eq(gen(operator.neg, Inf.count, gt(-5)),
        [0,-1,-2,-3,-4])
test_eq(gen(lambda o:o if o<5 else stop(), Inf.count),
        range(5))

In [None]:
#export
def chunked(it, cs, drop_last=False):
    "Yield successive lists of up to `cs` items from `it`; a short final chunk is skipped if `drop_last`"
    # `iter` returns an iterator unchanged, so one shared iterator is consumed across all chunks
    it = iter(it)
    while True:
        res = list(itertools.islice(it, cs))
        if res and (len(res)==cs or not drop_last): yield res
        # An empty chunk means `it` is exhausted (or `cs` is 0): stop, rather than spin forever
        if not res or len(res)<cs: return

In [None]:
t = L.range(10)
test_eq(chunked(t,3),      [[0,1,2], [3,4,5], [6,7,8], [9]])
test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8],    ])

t = map(lambda o:stop() if o==6 else o, Inf.count)
test_eq(chunked(t,3), [[0, 1, 2], [3, 4, 5]])
t = map(lambda o:stop() if o==7 else o, Inf.count)
test_eq(chunked(t,3), [[0, 1, 2], [3, 4, 5], [6]])

t = np.arange(10)
test_eq(chunked(t,3),      L([0,1,2], [3,4,5], [6,7,8], [9]))
test_eq(chunked(t,3,True), L([0,1,2], [3,4,5], [6,7,8],    ))

In [None]:
#export
def retain_type(new, old=None, typ=None):
    "Cast `new` to type of `old` if it's a superclass"
    # e.g. `old` is a TensorImage and `new` is a Tensor; if `old` is not an instance of `new`'s
    # type, nothing is done
    if new is None: return new
    # At least one of `old` / `typ` is needed to know the target type
    assert old is not None or typ is not None
    if typ is None:
        # Only cast when `old`'s type is `new`'s type or a subclass of it
        if not isinstance(old, type(new)): return new
        # `old` may be given as a type or as an instance
        typ = old if isinstance(old,type) else type(old)
    # Do nothing if `new` is already an instance of the requested type, or the requested type is `NoneType`
    return typ(new) if typ!=NoneType and not isinstance(new, typ) else new

In [None]:
class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
test_eq_type(retain_type(b, typ=_T), a)

In [None]:
#export
def retain_types(new, old=None, typs=None):
    "Cast each item of `new` to type of matching item in `old` if it's a superclass"
    # A non-listy `new` is handled as a single value (`typs` is passed on as `retain_type`'s `typ`)
    if not is_listy(new): return retain_type(new, old, typs)
    # Otherwise apply `retain_type` item by item and rebuild a container of the same type as `new`
    # NOTE(review): presumably `map_zip(..., cycled=True)` lines up `new`, `old` and `typs`,
    # repeating the shorter ones — confirm against `L.map_zip`
    return type(new)(L(new, old, typs).map_zip(retain_type, cycled=True))

In [None]:
class T(tuple): pass

t1,t2 = retain_types((1,(1,)), (2,T((2,))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,)))

In [None]:
#export
@patch
def split_arr(df:pd.DataFrame, from_col):
    "Split col `from_col` (containing arrays) in `DataFrame` `df` into separate columns, modifying `df` in place"
    col = df[from_col]
    # The number of new columns is taken from the first row; all rows are assumed to have that length
    n = len(col.iloc[0])
    cols = [f'{from_col}{o}' for o in range(n)]
    # Build the new columns on `df`'s own index: assigning a DataFrame aligns on index labels, so
    # a default RangeIndex would give NaNs whenever `df` has any other index (e.g. after filtering)
    df[cols] = pd.DataFrame(col.values.tolist(), index=df.index)
    df.drop(columns=from_col, inplace=True)

In [None]:
df = pd.DataFrame(dict(a=[[1,2,3],[4,5,6]]))
df.split_arr('a')
test_eq(df, pd.DataFrame(dict(a0=[1,4],a1=[2,5],a2=[3,6])))

## Simple types

In [None]:
#export
def show_title(o, ax=None, ctx=None, label=None, color='black', **kwargs):
    "Set title of `ax` to `o`, or print `o` if `ax` is `None`"
    # `ctx` is accepted as an alternative name for `ax`; `ax` wins when both are given.
    # Extra `kwargs` are accepted and ignored.
    if ax is None: ax = ctx
    if ax is None: print(o)
    elif hasattr(ax, 'set_title'):
        # Axes-like target: append `o` on a new line below any existing title
        t = ax.title.get_text()
        if len(t) > 0: o = t+'\n'+str(o)
        ax.set_title(o, color=color)
    elif isinstance(ax, pd.Series):
        # Series target: return a new Series with `o` added under `label`, adding `_` to the
        # label until it no longer clashes with an existing one
        while label in ax: label += '_'
        # `pd.concat` is used because `Series.append` was removed in pandas 2.0
        ax = pd.concat([ax, pd.Series({label: o})])
    return ax

In [None]:
test_stdout(lambda: show_title("title"), "title")
# ensure that col names are unique when showing to a pandas series
assert show_title("title", ctx=pd.Series(dict(a=1)), label='a').equals(pd.Series(dict(a=1,a_='title')))

In [None]:
#export
class ShowTitle:
    "Base class that adds a simple `show`"
    # Default keyword arguments for `show_title`; `kwargs` passed to `show` override them
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs): return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))

# Built-in value types with the `show` method mixed in
class Int(int, ShowTitle): pass
class Float(float, ShowTitle): pass
class Str(str, ShowTitle): pass
add_docs(Int, "An `int` with `show`"); add_docs(Str, "An `str` with `show`"); add_docs(Float, "An `float` with `show`")

In [None]:
show_doc(Int, title_level=3)

<h3 id="Int" class="doc_header"><code>class</code> <code>Int</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>Int</code>() :: `int`

An `int` with [`show`](/medical.imaging.html#show)

In [None]:
show_doc(Str, title_level=3)

<h3 id="Str" class="doc_header"><code>class</code> <code>Str</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>Str</code>() :: `str`

An `str` with [`show`](/medical.imaging.html#show)

In [None]:
show_doc(Float, title_level=3)

<h3 id="Float" class="doc_header"><code>class</code> <code>Float</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>Float</code>(**`x`**=*`0`*) :: `float`

An `float` with [`show`](/medical.imaging.html#show)

In [None]:
test_stdout(lambda: Str('s').show(), 's')
test_stdout(lambda: Int(1).show(), '1')

In [None]:
#export
# Names of the binary and unary numeric dunder methods
num_methods = """
    __add__ __sub__ __mul__ __matmul__ __truediv__ __floordiv__ __mod__ __divmod__ __pow__
    __lshift__ __rshift__ __and__ __xor__ __or__ __neg__ __pos__ __abs__
""".split()
# Reflected (right-hand operand) variants
rnum_methods = """
    __radd__ __rsub__ __rmul__ __rmatmul__ __rtruediv__ __rfloordiv__ __rmod__ __rdivmod__
    __rpow__ __rlshift__ __rrshift__ __rand__ __rxor__ __ror__
""".split()
# In-place (augmented assignment) variants
inum_methods = """
    __iadd__ __isub__ __imul__ __imatmul__ __itruediv__
    __ifloordiv__ __imod__ __ipow__ __ilshift__ __irshift__ __iand__ __ixor__ __ior__ 
""".split()

In [None]:
#export
class Tuple(tuple):
    "A `tuple` with elementwise ops and more friendly __init__ behavior"
    def __new__(cls, x=None, *rest):
        # `Tuple()` -> empty; `Tuple(a, b, ...)` -> those items; `Tuple(iterable)` -> its items;
        # `Tuple(non_iterable)` -> a one-item tuple
        if x is None: x = ()
        if not isinstance(x,tuple):
            # With several positional args the first is always one item, even if it is iterable
            if len(rest): x = (x,)
            else:
                try: x = tuple(iter(x))
                except TypeError: x = (x,)
        return super().__new__(cls, x+rest if rest else x)

    def _op(self,op,*args):
        # Also usable unbound with a plain tuple as `self` (e.g. `Tuple.add((1,1),(2,2))`)
        if not isinstance(self,Tuple): self = Tuple(self)
        # Apply `op` elementwise; each extra argument is passed through `cycle` first.
        # NOTE(review): presumably so that scalars and shorter sequences are repeated to match
        # `self` — the example cell uses `Tuple(1,1).add(2)`; confirm against `cycle`
        return type(self)(map(op,self,*map(cycle, args)))
    
    def mul(self,*args):
        "`*` is already defined in `tuple` for replicating, so use `mul` instead"
        return Tuple._op(self, operator.mul,*args)
    
    def add(self,*args):
        "`+` is already defined in `tuple` for concat, so use `add` instead"
        return Tuple._op(self, operator.add,*args)

def _get_op(op):
    "Make a method that applies `op` elementwise through `self._op`; a string `op` is looked up in `operator`"
    if isinstance(op,str): op = getattr(operator,op)
    def _f(self,*args): return self._op(op,*args)
    return _f

# Install elementwise versions of the numeric dunders, skipping names `Tuple` already has
# (including those inherited from `tuple`) and names that `operator` does not provide
for n in num_methods:
    if not hasattr(Tuple, n) and hasattr(operator,n): setattr(Tuple,n,_get_op(n))

# Elementwise comparisons are added as plain methods (`t.eq(...)`, `t.lt(...)`, ...)
for n in 'eq ne lt le gt ge'.split(): setattr(Tuple,n,_get_op(n))
# `~t` applies `operator.__not__` (logical not) to each item
setattr(Tuple,'__invert__',_get_op('__not__'))
# Elementwise `max` / `min` against the given argument(s)
setattr(Tuple,'max',_get_op(max))
setattr(Tuple,'min',_get_op(min))

In [None]:
test_eq(Tuple(1), (1,))
test_eq(type(Tuple(1)), Tuple)
test_eq_type(Tuple(1,2), Tuple(1,2))
test_ne(Tuple(1,2), Tuple(1,3))
test_eq(Tuple(), ())
test_eq(Tuple((1,2)), (1,2))
test_eq(-Tuple(1,2), (-1,-2))
test_eq(Tuple(1,1)-Tuple(2,2), (-1,-1))
test_eq(Tuple.add((1,1),(2,2)), (3,3))
test_eq(Tuple(1,1).add((2,2)), Tuple(3,3))
test_eq(Tuple('1','2').add('2'), Tuple('12','22'))
test_eq_type(Tuple(1,1).add(2), Tuple(3,3))
test_eq_type(Tuple(1,1).mul(2), Tuple(2,2))
test_eq(Tuple(3,1).le(1), (False, True))
test_eq(Tuple(3,1).eq(1), (False, True))
test_eq(Tuple(3,1).gt(1), (True, False))
test_eq(Tuple(3,1).min(2), (2,1))
test_eq(~Tuple(1,0,1), (False,True,False))

In [None]:
#export
class TupleTitled(Tuple, ShowTitle):
    "A `Tuple` with `show`"

## Functions on functions

In [None]:
#export
def trace(f):
    "Add `set_trace` to an existing function `f`"
    def _inner(*args,**kwargs):
        # Enter the debugger just before `f` is called
        set_trace()
        return f(*args,**kwargs)
    return _inner

In [None]:
# export
def compose(*funcs, order=None):
    "Return a function that applies each of `funcs` in turn (after `funcs.sorted(order)` if `order` is given), forwarding extra `*args` and `**kwargs` to every call"
    funcs = L(funcs)
    n_funcs = len(funcs)
    # Zero or one function needs no wrapper
    if n_funcs == 0: return noop
    if n_funcs == 1: return funcs[0]
    if order is not None: funcs = funcs.sorted(order)
    def _inner(x, *args, **kwargs):
        res = x
        for fn in funcs: res = fn(res, *args, **kwargs)
        return res
    return _inner

In [None]:
f1 = lambda o,p=0: (o*2)+p
f2 = lambda o,p=1: (o+1)/p
test_eq(f2(f1(3)), compose(f1,f2)(3))
test_eq(f2(f1(3,p=3),p=3), compose(f1,f2)(3,p=3))
test_eq(f2(f1(3,  3),  3), compose(f1,f2)(3,  3))

f1.order = 1
test_eq(f1(f2(3)), compose(f1,f2, order="order")(3))

In [None]:
#export
def maps(*args, retain=noop):
    "Like `map`, but every arg except the last is a function; they are composed and mapped over the last arg"
    fs,items = args[:-1],args[-1]
    composed = compose(*fs)
    # `retain` receives the mapped value and the original item
    def _apply(item): return retain(composed(item), item)
    return map(_apply, items)

In [None]:
test_eq(maps([1]), [1])
test_eq(maps(operator.neg, [1,2]), [-1,-2])
test_eq(maps(operator.neg, operator.neg, [1,2]), [1,2])

test_eq_type(list(maps(operator.neg, [Tuple((1,)), 2], retain=retain_type)), 
             [Tuple((-1,)), -2])

In [None]:
#export
def partialler(f, *args, order=None, **kwargs):
    "`functools.partial` that also carries over `f`'s docstring and its `order` attr (an explicit `order` wins)"
    bound = partial(f, *args, **kwargs)
    bound.__doc__ = f.__doc__
    if order is None:
        # No explicit order: inherit one from `f` if it has it
        if hasattr(f, 'order'): bound.order = f.order
    else:
        bound.order = order
    return bound

In [None]:
def _f(x,a=1):
    "test func"
    return x+a
_f.order=1

f = partialler(_f, a=2)
test_eq(f.order, 1)
f = partialler(_f, a=2, order=3)
test_eq(f.__doc__, "test func")
test_eq(f.order, 3)
test_eq(f(3), _f(3,2))

In [None]:
#export
def mapped(f, it):
    "Apply `f` to every item of `it` if it is listy, otherwise to `it` itself"
    if is_listy(it): return L(it).map(f)
    return f(it)

In [None]:
test_eq(mapped(_f,1),2)
test_eq(mapped(_f,[1,2]),[2,3])
test_eq(mapped(_f,(1,)),(2,))

In [None]:
#export
def instantiate(t):
    "Return `t()` when `t` is a class, otherwise `t` itself"
    if isinstance(t, type): return t()
    return t

In [None]:
test_eq_type(instantiate(int), 0)
test_eq_type(instantiate(1), 1)

In [None]:
#export
class _Self:
    "An alternative to `lambda` for calling methods on passed object."
    # Records a chain of attribute accesses and calls, then replays it on a real object:
    #   nms    - attribute names, in order
    #   args   - for each recorded call its positional args; None marks an attribute that was not called
    #   kwargs - same, for keyword args
    #   ready  - False right after an attribute access, while a call for it may still follow
    def __init__(self): self.nms,self.args,self.kwargs,self.ready = [],[],[],True
    def __repr__(self): return f'self: {self.nms}({self.args}, {self.kwargs})'

    def __call__(self, *args, **kwargs):
        if self.ready:
            # Replay mode: the first positional arg is the object to apply the chain to
            x = args[0]
            for n,a,k in zip(self.nms,self.args,self.kwargs):
                x = getattr(x,n)
                # Call only if a call was recorded for this name and the attribute is callable
                if callable(x) and a is not None: x = x(*a, **k)
            return x
        else:
            # Recording mode: this call belongs to the most recently accessed attribute
            self.args.append(args)
            self.kwargs.append(kwargs)
            self.ready = True
            return self

    def __getattr__(self,k):
        # Invoked only for names not found normally, i.e. the names being recorded
        if not self.ready:
            # The previous attribute was not followed by a call: record it as a plain access
            self.args.append(None)
            self.kwargs.append(None)
        self.nms.append(k)
        self.ready = False
        return self

class _SelfCls:
    # Every attribute access on `Self` starts a fresh `_Self` recorder
    def __getattr__(self,k): return getattr(_Self(),k)

Self = _SelfCls()

In [None]:
#export
_all_ = ['Self']

### Self

fastai provides a concise way to create lambdas that call methods on an object: use `Self` (note the capitalization!). `Self.sum()`, for instance, is a shortcut for `lambda o: o.sum()`.

In [None]:
f = Self.sum()
x = array([3.,1])
test_eq(f(x), 4.)

# This is equivalent to above
f = lambda o: o.sum()
x = array([3.,1])
test_eq(f(x), 4.)

f = Self.sum().is_integer()
x = array([3.,1])
test_eq(f(x), True)

f = Self.sum().real.is_integer()
x = array([3.,1])
test_eq(f(x), True)

f = Self.imag()
test_eq(f(3), 0)

## File and network functions

In [None]:
#export
@patch
def readlines(self:Path, hint=-1):
    "Read the lines of the file at `self` as a list (`hint` is passed to the file's `readlines`)"
    with self.open() as f: return f.readlines(hint)

In [None]:
#export
@patch
def read(self:Path, size=-1):
    "Read the content of the file at `self` (`size` is passed to the file's `read`)"
    with self.open() as f: return f.read(size)

In [None]:
#export
@patch
def write(self:Path, txt):
    "Write `txt` to `self`, creating directories as needed"
    # Create any missing parent directories first
    self.parent.mkdir(parents=True,exist_ok=True)
    # Mode 'w' replaces any existing content
    with self.open('w') as f: f.write(txt)

In [None]:
with tempfile.NamedTemporaryFile() as f:
    fn = Path(f.name)
    fn.write('t')
    t = fn.read()
    test_eq(t,'t')
    t = fn.readlines()
    test_eq(t,['t'])

In [None]:
#export
@patch
def save(fn:Path, o):
    "Save a pickle file, to a file name or opened file"
    # Anything that is not already an open file object is opened for binary writing
    if not isinstance(fn, io.IOBase): fn = open(fn,'wb')
    # The file is always closed afterwards, including one that was passed in already open
    try: pickle.dump(o, fn)
    finally: fn.close()

In [None]:
#export
@patch
def load(fn:Path):
    "Load a pickle file from a file name or opened file"
    # SECURITY: unpickling can execute arbitrary code; only load files from trusted sources
    # Anything that is not already an open file object is opened for binary reading
    if not isinstance(fn, io.IOBase): fn = open(fn,'rb')
    # The file is always closed afterwards, including one that was passed in already open
    try: return pickle.load(fn)
    finally: fn.close()

In [None]:
with tempfile.NamedTemporaryFile() as f:
    fn = Path(f.name)
    fn.save('t')
    t = fn.load()
test_eq(t,'t')

In [None]:
#export
#NB: Please don't move this to a different line or module, since it's used in testing `get_source_link`
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
    "Contents of path as a list, optionally limited to `n_max` items and filtered by `file_type` MIME prefix and/or `file_exts`"
    # Suffixes to keep; an empty collection means "no filtering"
    extns=L(file_exts)
    # Add every extension whose registered MIME type starts with `file_type/` (e.g. 'text/')
    if file_type: extns += L(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))
    # NOTE(review): despite its name, `has_extns` is True when NO extension filter was requested
    has_extns = len(extns)==0
    res = (o for o in self.iterdir() if has_extns or o.suffix in extns)
    # The cap is applied lazily, after filtering
    if n_max is not None: res = itertools.islice(res, n_max)
    return L(res)

We add an `ls()` method to `pathlib.Path` which, called with no arguments, is equivalent to `L(Path.iterdir())`; it exists mainly for convenience in REPL environments such as notebooks. Pass `n_max` to limit the number of items returned.

In [None]:
path = Path()
t = path.ls()
assert len(t)>0
t1 = path.ls(10)
test_eq(len(t1), 10)
t2 = path.ls(file_exts='.ipynb')
assert len(t)>len(t2)
t[0]

PosixPath('18_callback_fp16.ipynb')

You can also pass an optional `file_type` MIME prefix and/or a list of file extensions.

In [None]:
txt_files=path.ls(file_type='text')
assert len(txt_files) > 0 and txt_files[0].suffix=='.py'
ipy_files=path.ls(file_exts=['.ipynb'])
assert len(ipy_files) > 0 and ipy_files[0].suffix=='.ipynb'
txt_files[0],ipy_files[0]

(PosixPath('train_wt2.py'), PosixPath('18_callback_fp16.ipynb'))

In [None]:
#hide
pkl = pickle.dumps(path)
p2 =pickle.loads(pkl)
test_eq(path.ls()[0], p2.ls()[0])

In [None]:
#export
def bunzip(fn):
    "Decompress bzip2 file `fn` to the same path minus its last suffix; asserts that `fn` exists and the output does not"
    fn = Path(fn)
    assert fn.exists(), f"{fn} doesn't exist"
    out_fn = fn.with_suffix('')
    assert not out_fn.exists(), f"{out_fn} already exists"
    with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:
        # Copy in 1 MiB pieces so the whole file is never held in memory
        while True:
            chunk = src.read(1024*1024)
            if chunk == b'': break
            dst.write(chunk)

In [None]:
f = Path('files/test.txt')
if f.exists(): f.unlink()
bunzip('files/test.txt.bz2')
t = f.open().readlines()
test_eq(len(t),1)
test_eq(t[0], 'test\n')
f.unlink()

In [None]:
#export
def join_path_file(file, path, ext=''):
    "Return `path/file` (with `ext` appended) if file is a string or a `Path`, file otherwise"
    # Anything else (e.g. an already opened file object) is passed straight through
    if not isinstance(file, (str, Path)): return file
    # Accept `path` as a plain string too; `Path(path)` is a no-op for a `Path`
    path = Path(path)
    # Side effect: the target directory is created if it does not exist yet
    path.mkdir(parents=True, exist_ok=True)
    return path/f'{file}{ext}'

In [None]:
path = Path.cwd()/'_tmp'/'tst'
f = join_path_file('tst.txt', path)
assert path.exists()
test_eq(f, path/'tst.txt')
with open(f, 'w') as f_: assert join_path_file(f_, path) == f_
shutil.rmtree(Path.cwd()/'_tmp')

## Sorting objects from before/after

Transforms and callbacks can have `run_after`/`run_before` attributes; this function sorts them to respect those requirements (if that is possible). Also, sometimes we want a transform/callback to be run at the end, but still be able to use the `run_after`/`run_before` behaviors. For those, the function checks for a `toward_end` attribute (which needs to be `True`).

In [None]:
#export
def _is_instance(f, gs):
    tst = [g if type(g) in [type, 'function'] else g.__class__ for g in gs]
    for g in tst:
        if isinstance(f, g) or f==g: return True
    return False

def _is_first(f, gs):
    "Whether `f` may come before everything in `gs`: none of its `run_after` entries is in `gs`, and no item of `gs` lists it in `run_before`"
    must_follow = L(getattr(f, 'run_after', None))
    if any(_is_instance(o, gs) for o in must_follow): return False
    return not any(_is_instance(f, L(getattr(g, 'run_before', None))) for g in gs)

def sort_by_run(fs):
    "Order `fs` so that every item's `run_before`/`run_after` requirements hold; raises `Exception` if that is impossible"
    # NOTE(review): presumably `end` works as a boolean mask, so that items with `toward_end`
    # set are moved behind the others before sorting — confirm against `L.attrgot` / `L` indexing
    end = L(fs).attrgot('toward_end')
    inp,res = L(fs)[~end] + L(fs)[end], L()
    while len(inp):
        # Move the first item that is allowed to precede all remaining ones into `res`
        for i,o in enumerate(inp):
            if _is_first(o, inp):
                # Popping during iteration is safe here because the loop is left immediately
                res.append(inp.pop(i))
                break
        # The loop ended without `break`: no remaining item can go first
        else: raise Exception("Impossible to sort")
    return res

In [None]:
class Tst(): pass    
class Tst1():
    run_before=[Tst]
class Tst2():
    run_before=Tst
    run_after=Tst1
    
tsts = [Tst(), Tst1(), Tst2()]
test_eq(sort_by_run(tsts), [tsts[1], tsts[2], tsts[0]])

Tst2.run_before,Tst2.run_after = Tst1,Tst
test_fail(lambda: sort_by_run([Tst(), Tst1(), Tst2()]))

def tst1(x): return x
tst1.run_before = Tst
test_eq(sort_by_run([tsts[0], tst1]), [tst1, tsts[0]])
    
class Tst1():
    toward_end=True
class Tst2():
    toward_end=True
    run_before=Tst1
tsts = [Tst(), Tst1(), Tst2()]
test_eq(sort_by_run(tsts), [tsts[0], tsts[2], tsts[1]])

## Image helpers

In [None]:
#export
@delegates(plt.subplots, keep=True)
def subplots(nrows=1, ncols=1, figsize=None, imsize=4, **kwargs):
    "Like `plt.subplots`, but `figsize` defaults to `imsize` per cell and a single axes is returned wrapped in an array"
    size = (imsize*ncols,imsize*nrows) if figsize is None else figsize
    fig,axs = plt.subplots(nrows, ncols, figsize=size, **kwargs)
    # With one subplot, wrap the result so callers always get an array
    if nrows*ncols==1: return fig,array([axs])
    return fig,axs

In [None]:
#hide
_,axs = subplots()
test_eq(axs.shape,[1])
plt.close()
_,axs = subplots(2,3)
test_eq(axs.shape,[2,3])
plt.close()

## Other helpers

In [None]:
#export core
def display_df(df):
    "Display `df` in a notebook or defaults to print"
    # Only a missing IPython triggers the plain-text fallback; a bare `except` here would also
    # swallow `KeyboardInterrupt` and unrelated errors
    try: from IPython.display import display, HTML
    except ImportError: return print(df)
    display(HTML(df.to_html()))

In [None]:
#export
def round_multiple(x, mult, round_down=False):
    "Round `x` (or each item of a listy `x`) to a multiple of `mult`: the nearest, or with `round_down` the one given by `int()` truncation"
    rounder = int if round_down else round
    def _f(x_): return rounder(x_/mult)*mult
    res = L(x).map(_f)
    # A scalar input gives a scalar result
    if is_listy(x): return res
    return res[0]

In [None]:
test_eq(round_multiple(63,32), 64)
test_eq(round_multiple(50,32), 64)
test_eq(round_multiple(40,32), 32)
test_eq(round_multiple( 0,32),  0)
test_eq(round_multiple(63,32, round_down=True), 32)
test_eq(round_multiple((63,40),32), (64,32))

In [None]:
#export
def even_mults(start, stop, n):
    "Array of `n` values from `start` to `stop` with a constant ratio between neighbours; for `n==1` returns just `stop`"
    if n==1: return stop
    # Constant ratio between consecutive values
    step = (stop/start)**(1/(n-1))
    return np.array([start*(step**k) for k in range(n)])

In [None]:
test_eq(even_mults(2,8,3), [2,4,8])
test_eq(even_mults(2,32,5), [2,4,8,16,32])
test_eq(even_mults(2,8,1), 8)

In [None]:
#export
def num_cpus():
    "Number of CPUs this process may run on, or `os.cpu_count()` where `os.sched_getaffinity` is unavailable"
    try: affinity = os.sched_getaffinity(0)
    except AttributeError: return os.cpu_count()
    return len(affinity)

defaults.cpus = num_cpus()

In [None]:
#export
def add_props(f, n=2):
    "Generator of `n` properties; the getter of the i-th one calls `f(i, obj)`"
    idxs = range(n)
    return (property(partial(f, idx)) for idx in idxs)

In [None]:
class _T(): a,b = add_props(lambda i,x:i*2)

t = _T()
test_eq(t.a,0)
test_eq(t.b,2)

# Export -

In [None]:
#hide
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_utils.ipynb.
Converted 01b_dispatch.ipynb.
Converted 01c_torch_core.ipynb.
Converted 02_script.ipynb.
Converted 03_dataloader.ipynb.
Converted 04_transform.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 10_data_block.ipynb.
Converted 11_layers.ipynb.
Converted 11a_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 14a_callback_data.ipynb.
Converted 15_callback_hook.ipynb.
Converted 15a_vision_models_unet.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_vision_learner.ipynb.
Converted 22_tutorial_imagenette.ipynb.
Converted 23_tutorial_transfer_learning.ipynb.
Conver