In [None]:
#default_exp xtras

In [None]:
#export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.basics import *
from functools import wraps

import mimetypes,bz2,pickle,random,json,urllib,subprocess,shlex,bz2,gzip,distutils.util,imghdr,struct
from contextlib import contextmanager,ExitStack
from pdb import set_trace
from urllib.request import Request,urlopen
from urllib.error import HTTPError
from urllib.parse import urlencode
from threading import Thread

In [None]:
from fastcore.test import *
from nbdev.showdoc import *
from fastcore.nb_imports import *

# Utility functions

> Utility functions used in the fastai library

## Collections

In [None]:
#export
def dict2obj(d):
    "Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`"
    # Lists (plain or `L`) are converted element by element; the result is built with `L`
    if isinstance(d, (L,list)): return L(d).map(dict2obj)
    # Anything that is neither a list nor a dict is a leaf and is returned untouched
    if not isinstance(d, dict): return d
    # Pass the mapping positionally: `AttrDict(**mapping)` raises `TypeError` as soon as a key is not a `str`
    # NOTE(review): assumes `AttrDict` accepts a mapping like `dict` does - confirm against its definition
    return AttrDict({k:dict2obj(v) for k,v in d.items()})

This is a convenience to give you "dotted" access to (possibly nested) dictionaries, e.g:

In [None]:
d1 = dict(a=1, b=dict(c=2,d=3))
d2 = dict2obj(d1)
test_eq(d2.b.c, 2)
test_eq(d2.b['c'], 2)
d2

{'a': 1, 'b': {'c': 2, 'd': 3}}

It can also be used on lists of dicts.

In [None]:
ds = L(d1, d1)
test_eq(dict2obj(ds)[0].b.c, 2)

In [None]:
#export
def tuplify(o, use_list=False, match=None):
    "Turn `o` into a tuple by first building an `L` from it (forwarding `use_list` and `match`)"
    items = L(o, use_list=use_list, match=match)
    return tuple(items)

In [None]:
test_eq(tuplify(None),())
test_eq(tuplify([1,2,3]),(1,2,3))
test_eq(tuplify(1,match=[1,2,3]),(1,1,1))

In [None]:
#export
def uniqueify(x, sort=False, bidir=False, start=None):
    "Unique elements of `x`; optionally prepend `start`, `sort` the result, and (`bidir`) also return the value->index mapping"
    uniq = L(x).unique()
    # `start` is prepended as-is (it is not de-duplicated against `x`)
    if start is not None: uniq = start+uniq
    if sort: uniq.sort()
    return (uniq, uniq.val2idx()) if bidir else uniq

In [None]:
# test
test_eq(set(uniqueify([1,1,0,5,0,3])),{0,1,3,5})
test_eq(uniqueify([1,1,0,5,0,3], sort=True),[0,1,3,5])
test_eq(uniqueify([1,1,0,5,0,3], start=[7,8,6]), [7,8,6,1,0,5,3])
v,o = uniqueify([1,1,0,5,0,3], bidir=True)
test_eq(v,[1,0,5,3])
test_eq(o,{1:0, 0: 1, 5: 2, 3: 3})
v,o = uniqueify([1,1,0,5,0,3], sort=True, bidir=True)
test_eq(v,[0,1,3,5])
test_eq(o,{0:0, 1: 1, 3: 2, 5: 3})

In [None]:
#export
def is_listy(x):
    "`isinstance(x, (tuple,list,L,slice,Generator))`"
    listy_types = (tuple,list,L,slice,Generator)
    return isinstance(x, listy_types)

In [None]:
assert is_listy((1,))
assert is_listy([1])
assert is_listy(L([1]))
assert is_listy(slice(2))
assert not is_listy(array([1]))

In [None]:
#export
def shufflish(x, pct=0.04):
    "Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
    n = len(x)
    # Each position gets a sort key equal to itself plus a random offset of at most `pct*n`
    def _jittered(pos): return pos+n*(1+random.random()*pct)
    order = sorted(range_of(x), key=_jittered)
    return L(x[i] for i in order)

In [None]:
#export
def mapped(f, it):
    "Apply `f` to every element of `it` when `it` is listy; otherwise apply `f` to `it` itself"
    if not is_listy(it): return f(it)
    return L(it).map(f)

In [None]:
def _f(x,a=1): return x-a

test_eq(mapped(_f,1),0)
test_eq(mapped(_f,[1,2]),[0,1])
test_eq(mapped(_f,(1,)),(0,))

## Reindexing Collections

In [None]:
#export
#hide
class IterLen:
    "Mixin giving iteration to any class that implements `__len__` and `__getitem__`"
    def __iter__(self):
        positions = range_of(self)
        return (self[pos] for pos in positions)

In [None]:
#export
@docs
class ReindexCollection(GetAttr, IterLen):
    "Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
    _default='coll'
    def __init__(self, coll, idxs=None, cache=None, tfm=noop):
        if idxs is None: idxs = L.range(coll)
        store_attr()
        self._wrap_cache()

    def _wrap_cache(self):
        "Shadow `_get` on the instance with an LRU-cached version when a cache size was requested"
        # Drop any earlier wrapper first so a cached function is never wrapped a second time
        self.__dict__.pop('_get', None)
        if self.cache is not None: self._get = functools.lru_cache(maxsize=self.cache)(self._get)

    def _get(self, i): return self.tfm(self.coll[i])
    def __getitem__(self, i): return self._get(self.idxs[i])
    def __len__(self): return len(self.coll)
    def reindex(self, idxs): self.idxs = idxs
    def shuffle(self): random.shuffle(self.idxs)
    def cache_clear(self):
        # Without a cache `_get` is a plain method with nothing to clear
        if self.cache is not None: self._get.cache_clear()
    # The cached `_get` wrapper is deliberately left out of the pickled state...
    def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}
    # ...so it has to be rebuilt on load, otherwise the restored object silently loses its cache
    def __setstate__(self, s):
        self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']
        self._wrap_cache()

    _docs = dict(reindex="Replace `self.idxs` with idxs",
                shuffle="Randomly shuffle indices",
                cache_clear="Clear LRU cache")

In [None]:
show_doc(ReindexCollection, title_level=4)

<h4 id="ReindexCollection" class="doc_header"><code>class</code> <code>ReindexCollection</code><a href="" class="source_link" style="float:right">[source]</a></h4>

> <code>ReindexCollection</code>(**`coll`**, **`idxs`**=*`None`*, **`cache`**=*`None`*, **`tfm`**=*`noop`*) :: [`GetAttr`](/foundation.html#GetAttr)

Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`

This is useful when constructing batches or organizing data in a particular manner (e.g. for deep learning).  This class is primarily used in organizing data for language models in fastai.

You can supply a custom index upon instantiation with the `idxs` argument, or you can call the `reindex` method to supply a new index for your collection.

Here is how you can reindex a list such that the elements are reversed:

In [None]:
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'], idxs=[4,3,2,1,0])
list(rc)

['e', 'd', 'c', 'b', 'a']

Alternatively, you can use the `reindex` method:

In [None]:
show_doc(ReindexCollection.reindex, title_level=6)

<h6 id="ReindexCollection.reindex" class="doc_header"><code>ReindexCollection.reindex</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h6>

> <code>ReindexCollection.reindex</code>(**`idxs`**)

Replace `self.idxs` with idxs

In [None]:
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'])
rc.reindex([4,3,2,1,0])
list(rc)

['e', 'd', 'c', 'b', 'a']

You can optionally specify a LRU cache, which uses [functools.lru_cache](https://docs.python.org/3/library/functools.html#functools.lru_cache) upon instantiation:

In [None]:
sz = 50
t = ReindexCollection(L.range(sz), cache=2)

#trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t._get.cache_info()

CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)

You can optionally clear the LRU cache by calling the `cache_clear` method:

In [None]:
show_doc(ReindexCollection.cache_clear, title_level=5)

<h5 id="ReindexCollection.cache_clear" class="doc_header"><code>ReindexCollection.cache_clear</code><a href="__main__.py#L16" class="source_link" style="float:right">[source]</a></h5>

> <code>ReindexCollection.cache_clear</code>()

Clear LRU cache

In [None]:
sz = 50
t = ReindexCollection(L.range(sz), cache=2)

#trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t.cache_clear()
t._get.cache_info()

CacheInfo(hits=0, misses=0, maxsize=2, currsize=0)

In [None]:
show_doc(ReindexCollection.shuffle, title_level=5)

<h5 id="ReindexCollection.shuffle" class="doc_header"><code>ReindexCollection.shuffle</code><a href="__main__.py#L15" class="source_link" style="float:right">[source]</a></h5>

> <code>ReindexCollection.shuffle</code>()

Randomly shuffle indices

Note that an ordered index is automatically constructed for the data structure even if one is not supplied.

In [None]:
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
rc.shuffle()
list(rc)

['a', 'd', 'c', 'b', 'e', 'g', 'f', 'h']

In [None]:
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
test_eq(list(t), range(sz))
test_eq(t[sz-1], sz-1)
test_eq(t._get.cache_info().hits, 1)
t.shuffle()
test_eq(t._get.cache_info().hits, 1)
test_ne(list(t), range(sz))
test_eq(set(t), set(range(sz)))
t.cache_clear()
test_eq(t._get.cache_info().hits, 0)
test_eq(t.count(0), 1)

In [None]:
#hide
#Test ReindexCollection pickles
t1 = pickle.loads(pickle.dumps(t))
test_eq(list(t), list(t1))

## Extensions to Pathlib.Path

An extension of the standard python library [Pathlib.Path](https://docs.python.org/3/library/pathlib.html#basic-use).  These extensions are accomplished by monkey patching additional methods onto `Pathlib.Path`.

In [None]:
#export
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
    "Open `self` as text with `encoding` and return its lines (`hint` is forwarded to the file's `readlines`)"
    with self.open(encoding=encoding) as fh:
        lines = fh.readlines(hint)
    return lines

In [None]:
#export
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
    "Create any missing parent dirs of `self` (with permissions `mode`), then write text `data` to it"
    # The default `mode` of 511 is 0o777
    parent_dir = self.parent
    parent_dir.mkdir(mode=mode, parents=True, exist_ok=True)
    self.write_text(data, encoding=encoding, errors=errors)

In [None]:
#export
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
    "Contents of path as a list, optionally limited to `n_max` items and filtered by MIME `file_type` prefix and/or `file_exts`"
    # Accepted suffixes are kept in a set: a MIME prefix can expand to many extensions and each
    # directory entry is tested against all of them
    extns = set(L(file_exts))
    if file_type: extns.update(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))
    # With no filter requested every entry is kept
    res = (o for o in self.iterdir() if not extns or o.suffix in extns)
    if n_max is not None: res = itertools.islice(res, n_max)
    return L(res)

We add an `ls()` method to `pathlib.Path` which is simply defined as `list(Path.iterdir())`, mainly for convenience in REPL environments such as notebooks.

In [None]:
path = Path()
t = path.ls()
assert len(t)>0
t1 = path.ls(10)
test_eq(len(t1), 10)
t2 = path.ls(file_exts='.ipynb')
assert len(t)>len(t2)
t[0]

Path('.ipynb_checkpoints')

You can also pass an optional `file_type` MIME prefix and/or a list of file extensions.

In [None]:
lib_path = (path/'../fastcore')
txt_files=lib_path.ls(file_type='text')
assert len(txt_files) > 0 and txt_files[0].suffix=='.py'
ipy_files=path.ls(file_exts=['.ipynb'])
assert len(ipy_files) > 0 and ipy_files[0].suffix=='.ipynb'
txt_files[0],ipy_files[0]

(Path('../fastcore/logargs.py'), Path('00_test.ipynb'))

In [None]:
#hide
path = Path()
pkl = pickle.dumps(path)
p2 = pickle.loads(pkl)
test_eq(path.ls()[0], p2.ls()[0])

In [None]:
#export
def open_file(fn, mode='r'):
    "Open `fn` with `mode`, using bz2/gzip when its suffix is `.bz2`/`.gz`; an already-open file is returned as-is"
    if isinstance(fn, io.IOBase): return fn
    fn = Path(fn)
    # Choose the opener from the suffix, falling back to the builtin `open`
    openers = {'.bz2': bz2.BZ2File, '.gz': gzip.GzipFile}
    return openers.get(fn.suffix, open)(fn, mode)

In [None]:
#export
def save_pickle(fn, o):
    "Pickle `o` into `fn`, which may be a file name (compressed if `.gz`/`.bz2`) or an opened binary file"
    with open_file(fn, 'wb') as out:
        pickle.dump(o, out)

In [None]:
#export
def load_pickle(fn):
    "Unpickle and return the object stored in `fn`, which may be a file name or an opened binary file"
    # NOTE: unpickling can execute arbitrary code, so only load files from trusted sources
    with open_file(fn, 'rb') as src:
        obj = pickle.load(src)
    return obj

In [None]:
for suf in '.pkl','.bz2','.gz':
    with tempfile.NamedTemporaryFile(suffix=suf) as f:
        fn = Path(f.name)
        save_pickle(fn, 't')
        t = load_pickle(fn)
    test_eq(t,'t')

In [None]:
#export
@patch
def __repr__(self:Path):
    "Show `self` relative to `Path.BASE_PATH` when that is set and contains `self`; otherwise show it unchanged"
    b = getattr(Path, 'BASE_PATH', None)
    if b:
        # `relative_to` raises `ValueError` when `self` is not inside `b` (`TypeError` for an unusable `b`);
        # only those are ignored, so unrelated failures are no longer hidden by a bare `except`
        try: self = self.relative_to(b)
        except (ValueError, TypeError): pass
    return f"Path({self.as_posix()!r})"

fastai also updates the `repr` of `Path` such that, if `Path.BASE_PATH` is defined, all paths are printed relative to that path (as long as they are contained in `Path.BASE_PATH`):

In [None]:
t = ipy_files[0].absolute()
try:
    Path.BASE_PATH = t.parent.parent
    test_eq(repr(t), f"Path('nbs/{t.name}')")
finally: Path.BASE_PATH = None

## File Functions

Utilities (other than extensions to Pathlib.Path) for dealing with IO.

In [None]:
# export
@contextmanager
def maybe_open(f, mode='r', **kwargs):
    "Context manager: yield `f` untouched if it is not a path; if it is a path, open it and close it on exit"
    if not isinstance(f, (str,os.PathLike)):
        # The caller owns this handle, so it is handed back as-is and left open
        yield f
        return
    with open(f, mode, **kwargs) as fh: yield fh

This is useful for functions where you want to accept a path *or* file. `maybe_open` will not close your file handle if you pass one in.

In [None]:
def _f(fn):
    with maybe_open(fn) as f: return f.encoding

fname = '00_test.ipynb'
test_eq(_f(fname), 'UTF-8')
with open(fname) as fh: test_eq(_f(fh), 'UTF-8')

For example, we can use this to reimplement [`imghdr.what`](https://docs.python.org/3/library/imghdr.html#imghdr.what) from the Python standard library, which is [written in Python 3.9](https://github.com/python/cpython/blob/3.9/Lib/imghdr.py#L11) as:

In [None]:
# Reference copy of the standard library's `imghdr.what`, kept as-is for comparison with the rewrite below
def what(file, h=None):
    f = None
    try:
        if h is None:
            if isinstance(file, (str,os.PathLike)):
                # A path: open it ourselves (closed in the `finally` below) and read the header
                f = open(file, 'rb')
                h = f.read(32)
            else:
                # A file object: read the header, then put the position back where it was
                location = file.tell()
                h = file.read(32)
                file.seek(location)
        # Return the first truthy result from the registered test functions
        for tf in imghdr.tests:
            res = tf(h, f)
            if res: return res
    finally:
        if f: f.close()
    return None

Here's an example of the use of this function:

In [None]:
fname = 'images/puppy.jpg'
what(fname)

'jpeg'

With `maybe_open`, `Self`, and `L.map_first`, we can rewrite this in a much more concise and (in our opinion) clear way:

In [None]:
def what(file, h=None):
    "Concise rewrite of `imghdr.what`: check header bytes `h` (taken from `file` when `h` is None) against `imghdr.tests`"
    if h is None:
        # NOTE(review): `peek` may return more or fewer than 32 bytes and only exists on buffered readers - confirm callers
        with maybe_open(file, 'rb') as f: h = f.peek(32)
    # presumably returns the first truthy result of calling each test with `(h, file)` - verify against `L.map_first`
    return L(imghdr.tests).map_first(Self(h,file))

...and we can check that it still works:

In [None]:
test_eq(what(fname), 'jpeg')

...along with the version passing a file handle:

In [None]:
with open(fname,'rb') as f: test_eq(what(f), 'jpeg')

...along with the `h` parameter version:

In [None]:
with open(fname,'rb') as f: test_eq(what(None, h=f.read(32)), 'jpeg')

In [None]:
def _jpg_size(f):
        size,ftype = 2,0
        while not 0xc0 <= ftype <= 0xcf:
            f.seek(size, 1)
            byte = f.read(1)
            while ord(byte) == 0xff: byte = f.read(1)
            ftype = ord(byte)
            size = struct.unpack('>H', f.read(2))[0] - 2
        f.seek(1, 1)  # `precision'
        h,w = struct.unpack('>HH', f.read(4))
        return w,h

def _gif_size(f): return struct.unpack('<HH', head[6:10])

def _png_size(f):
    assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a
    return struct.unpack('>ii', head[16:24])

In [None]:
#export
def image_size(fn):
    "Tuple of (w,h) for png, gif, or jpg; `None` otherwise"
    # NOTE: `imghdr` is deprecated since Python 3.11 and removed in 3.13
    d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)
    with maybe_open(fn, 'rb') as f:
        # `.get` so that unsupported or unrecognised formats give `None`, as documented, instead of `KeyError`
        size_fn = d.get(imghdr.what(f))
        return size_fn(f) if size_fn else None

In [None]:
test_eq(image_size(fname), (1200,803))

In [None]:
#export
def bunzip(fn):
    "Decompress bz2 file `fn` to the same path minus its final suffix; fails if `fn` is missing or the output exists"
    src_path = Path(fn)
    assert src_path.exists(), f"{src_path} doesn't exist"
    dst_path = src_path.with_suffix('')
    assert not dst_path.exists(), f"{dst_path} already exists"
    with bz2.BZ2File(src_path, 'rb') as src, dst_path.open('wb') as dst:
        # Copy in 1MiB chunks so the whole archive is never held in memory
        while True:
            chunk = src.read(1024*1024)
            if not chunk: break
            dst.write(chunk)

In [None]:
f = Path('files/test.txt')
if f.exists(): f.unlink()
bunzip('files/test.txt.bz2')
t = f.open().readlines()
test_eq(len(t),1)
test_eq(t[0], 'test\n')
f.unlink()

In [None]:
#export
def join_path_file(file, path, ext=''):
    "Return `path/file` (plus `ext`) if file is a string or a `Path`, creating `path` if needed; return file unchanged otherwise"
    # Anything that is not a name (e.g. an already opened file) is passed straight through
    if not isinstance(file, (str, Path)): return file
    # Accept `path` as a plain string too, not only as a `Path`
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path/f'{file}{ext}'

In [None]:
path = Path.cwd()/'_tmp'/'tst'
f = join_path_file('tst.txt', path)
assert path.exists()
test_eq(f, path/'tst.txt')
with open(f, 'w') as f_: assert join_path_file(f_, path) == f_
shutil.rmtree(Path.cwd()/'_tmp')

In [None]:
#export
def urlread(url, data=None, **kwargs):
    "Fetch `url` (a string or a `Request`) and return the raw response bytes; `data` or, failing that, `kwargs` is sent as the body"
    payload = kwargs if (kwargs and not data) else data
    if payload is not None and not isinstance(payload, bytes):
        # Mappings/sequences are form-encoded first; text is then sent as ASCII bytes
        text = payload if isinstance(payload, str) else urlencode(payload)
        payload = text.encode('ascii')
    req = url if isinstance(url, urllib.request.Request) else urllib.request.Request(url)
    req.headers['User-Agent'] = 'Mozilla/5.0'
    with urlopen(req, data=payload) as res: return res.read()

In [None]:
#export
def urljson(url, data=None):
    "Fetch `url` with `urlread` (passing `data` along) and parse the response as JSON"
    body = urlread(url, data=data)
    return json.loads(body)

In [None]:
#export
def run(cmd, *rest, ignore_ex=False, as_bytes=False):
    "Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails"
    # Extra positional args mean the command was passed already split: `run('ls', '-l')`
    if rest: cmd = (cmd,)+rest
    elif isinstance(cmd,str): cmd = shlex.split(cmd)
    res = subprocess.run(cmd, capture_output=True)
    stdout = res.stdout if as_bytes else res.stdout.decode()
    # With `ignore_ex` a failing exit status is reported to the caller as `(returncode, stdout)` instead of raised
    if ignore_ex: return (res.returncode, stdout)
    if res.returncode:
        # Decode both streams for the message so it is readable text rather than `b'...'` reprs;
        # undecodable bytes are replaced instead of masking the real failure
        out,err = (o.decode(errors='replace') for o in (res.stdout, res.stderr))
        raise IOError("{} ;; {}".format(out, err))
    return stdout

You can pass a string (which will be split based on standard shell rules), a list, or pass args directly:

In [None]:
assert 'ipynb' in run('ls -l')
assert 'ipynb' in run(['ls', '-l'])
assert 'ipynb' in run('ls', '-l')

Some commands fail in non-error situations, like `grep`. Use `ignore_ex` in those cases, which will return a tuple of returncode and stdout:

In [None]:
test_eq(run('grep asdfds 00_test.ipynb', ignore_ex=True)[0], 1)

`run` automatically decodes returned bytes to a `str`. Use `as_bytes` to skip that:

In [None]:
test_eq(run('echo hi', as_bytes=True), b'hi\n')

In [None]:
#export
def do_request(url, post=False, headers=None, **data):
    "Call GET or json-encoded POST on `url`, depending on `post`"
    if data:
        if post: data = json.dumps(data).encode('ascii')
        else:
            # For GET the data goes into the query string, extending an existing one if `url` already has it
            url += ('&' if '?' in url else '?') + urlencode(data)
            data = None
    # `Request` iterates over `headers`, so the `None` default has to become an empty dict
    return urljson(Request(url, headers=headers or {}, data=data or None))

## Other Helpers

In [None]:
#export
def _is_instance(f, gs):
    tst = [g if type(g) in [type, 'function'] else g.__class__ for g in gs]
    for g in tst:
        if isinstance(f, g) or f==g: return True
    return False

def _is_first(f, gs):
    "True if `f` may precede everything in `gs`: none of its `run_after` matches `gs`, and no `g` in `gs` lists `f` in `run_before`"
    # `f` must wait if something it has to run after is still in `gs`
    if any(_is_instance(o, gs) for o in L(getattr(f, 'run_after', None))): return False
    # `f` must also wait if any remaining item asks to run before it
    return not any(_is_instance(f, L(getattr(g, 'run_before', None))) for g in gs)

def sort_by_run(fs):
    "Order `fs` so `run_before`/`run_after` constraints hold, considering `toward_end` items last; raise if impossible"
    end = L(fs).attrgot('toward_end')
    # presumably `end` acts as a boolean mask here: items without `toward_end` first, then those with it - verify against `L`
    inp,res = L(fs)[~end] + L(fs)[end], L()
    while len(inp):
        # Move the first remaining item that may legally come next into `res`, then rescan what is left
        for i,o in enumerate(inp):
            if _is_first(o, inp):
                res.append(inp.pop(i))
                break
        # A full pass without a candidate means the constraints contradict each other
        else: raise Exception("Impossible to sort")
    return res

Transforms and callbacks will have `run_after`/`run_before` attributes; this function sorts them to respect those requirements (if possible). Also, sometimes we want a transform/callback to be run at the end, but still be able to use `run_after`/`run_before` behaviors. For those, the function checks for a `toward_end` attribute (which needs to be `True`).

In [None]:
class Tst(): pass    
class Tst1(): run_before=[Tst]
class Tst2():
    run_before=Tst
    run_after=Tst1
    
tsts = [Tst(), Tst1(), Tst2()]
test_eq(sort_by_run(tsts), [tsts[1], tsts[2], tsts[0]])

Tst2.run_before,Tst2.run_after = Tst1,Tst
test_fail(lambda: sort_by_run([Tst(), Tst1(), Tst2()]))

def tst1(x): return x
tst1.run_before = Tst
test_eq(sort_by_run([tsts[0], tst1]), [tst1, tsts[0]])
    
class Tst1():
    toward_end=True
class Tst2():
    toward_end=True
    run_before=Tst1
tsts = [Tst(), Tst1(), Tst2()]
test_eq(sort_by_run(tsts), [tsts[0], tsts[2], tsts[1]])

In [None]:
#export
def trace(f):
    "Add `set_trace` to an existing function `f`"
    # Already wrapped: return it unchanged so the debugger is not entered twice per call
    if getattr(f, '_traced', False): return f
    # `wraps` keeps `f`'s name and docstring on the wrapper
    @wraps(f)
    def _inner(*args,**kwargs):
        set_trace()
        return f(*args,**kwargs)
    _inner._traced = True
    return _inner

You can add a breakpoint to an existing function, e.g:

```python
Path.cwd = trace(Path.cwd)
Path.cwd()
```

Now, when the function is called it will drop you into the debugger.  Note, you must issue  the `s` command when you begin to step into the function that is being traced.

In [None]:
#export
def round_multiple(x, mult, round_down=False):
    "Round `x` to nearest multiple of `mult` (using `int` instead of `round` when `round_down`); listy `x` is handled per element"
    rounder = int if round_down else round
    res = L(x).map(lambda x_: rounder(x_/mult)*mult)
    # A scalar input gets a scalar back
    return res if is_listy(x) else res[0]

In [None]:
test_eq(round_multiple(63,32), 64)
test_eq(round_multiple(50,32), 64)
test_eq(round_multiple(40,32), 32)
test_eq(round_multiple( 0,32),  0)
test_eq(round_multiple(63,32, round_down=True), 32)
test_eq(round_multiple((63,40),32), (64,32))

In [None]:
#export
@contextmanager
def modified_env(*delete, **replace):
    "Context manager that removes the `delete` names from `os.environ` and sets the `replace` items, restoring everything on exit"
    snapshot = os.environ.copy()
    try:
        os.environ.update(replace)
        for name in delete:
            # Names that are not currently set are simply skipped
            if name in os.environ: del os.environ[name]
        yield
    finally:
        # Restore the exact previous environment, dropping anything added inside the block
        os.environ.clear()
        os.environ.update(snapshot)

In [None]:
oldsh = os.environ['SHELL']

with modified_env('PATH', SHELL='a'):
    test_eq(os.environ['SHELL'], 'a')
    assert 'PATH' not in os.environ

assert 'PATH' in os.environ
test_eq(os.environ['SHELL'], oldsh)

In [None]:
#export
class ContextManagers(GetAttr):
    "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    def __init__(self, mgrs): self.default,self.stack = L(mgrs),ExitStack()
    def __enter__(self):
        self.default.map(self.stack.enter_context)
        # Return the wrapper so `with ContextManagers(...) as cms:` binds something usable
        return self
    # Propagate the stack's result: a truthy value means one of the managers suppressed the exception
    def __exit__(self, *args, **kwargs): return self.stack.__exit__(*args, **kwargs)

In [None]:
show_doc(ContextManagers, title_level=4)

<h4 id="ContextManagers" class="doc_header"><code>class</code> <code>ContextManagers</code><a href="" class="source_link" style="float:right">[source]</a></h4>

> <code>ContextManagers</code>(**`mgrs`**) :: [`GetAttr`](/foundation.html#GetAttr)

Wrapper for `contextlib.ExitStack` which enters a collection of context managers

In [None]:
#export
def str2bool(s):
    "Case-insensitive convert string `s` to a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)"
    # Non-strings use normal truthiness; the empty string is False
    if not isinstance(s,str): return bool(s)
    if not s: return False
    # Same accepted spellings and error as `distutils.util.strtobool`, without needing `distutils`
    # (which is removed in Python 3.12)
    val = s.lower()
    if val in ('y','yes','t','true','on','1'): return True
    if val in ('n','no','f','false','off','0'): return False
    raise ValueError(f"invalid truth value {val!r}")

In [None]:
for o in "y YES t True on 1".split(): assert str2bool(o)
for o in "n no FALSE off 0".split(): assert not str2bool(o)
for o in 0,None,'',False: assert not str2bool(o)
for o in 1,True: assert str2bool(o)

## Multiprocessing

In [None]:
#export
from multiprocessing import Process, Queue
import concurrent.futures
import time
from multiprocessing import Manager

In [None]:
#export
def set_num_threads(nt):
    "Get numpy (and others) to use `nt` threads"
    # Export the limits first, so that a library imported for the first time below can already see them
    os.environ['IPC_ENABLE']='1'
    for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
        os.environ[o] = str(nt)
    # `mkl` and `torch` are optional: a missing or failing library is skipped, while
    # `KeyboardInterrupt`/`SystemExit` are no longer swallowed
    try: import mkl; mkl.set_num_threads(nt)
    except Exception: pass
    try: import torch; torch.set_num_threads(nt)
    except Exception: pass

This sets the number of threads consistently for many tools, by:

1. Setting the following environment variables equal to `nt`: `OPENBLAS_NUM_THREADS`,`NUMEXPR_NUM_THREADS`,`OMP_NUM_THREADS`,`MKL_NUM_THREADS`
2. Setting `nt` threads for numpy (via `mkl`) and pytorch, when those libraries are installed.

In [None]:
#export
def _call(lock, pause, n, g, item):
    l = False
    if pause:
        try:
            l = lock.acquire(timeout=pause*(n+2))
            time.sleep(pause)
        finally:
            if l: lock.release()
    return g(item)

In [None]:
#export
class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
    "Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution"
    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
        if max_workers is None: max_workers=defaults.cpus
        store_attr()
        self.not_parallel = max_workers==0
        # The parent class needs at least one worker, even though none is used in serial mode
        if self.not_parallel: max_workers=1
        super().__init__(max_workers, **kwargs)

    def map(self, f, items, timeout=None, chunksize=1, *args, **kwargs):
        "Lazily map `f` (with extra `args`/`kwargs`) over `items`: in this process if serial, otherwise in the pool"
        # `_call` only uses the lock when pausing, so a manager process is started only in that case
        self.lock = Manager().Lock() if (self.pause and not self.not_parallel) else None
        g = partial(f, *args, **kwargs)
        if self.not_parallel: return map(g, items)
        _g = partial(_call, self.lock, self.pause, self.max_workers, g)
        # NOTE(review): only errors raised while submitting are caught here (`None` is then returned);
        # errors inside workers surface when the returned iterator is consumed
        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
        except Exception as e: self.on_exc(e)

In [None]:
#export
class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    "Same as Python's ThreadPoolExecutor, except can pass `max_workers==0` for serial execution"
    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
        if max_workers is None: max_workers=defaults.cpus
        store_attr()
        self.not_parallel = max_workers==0
        # The parent class needs at least one worker, even though none is used in serial mode
        if self.not_parallel: max_workers=1
        super().__init__(max_workers, **kwargs)

    def map(self, f, items, timeout=None, chunksize=1, *args, **kwargs):
        "Lazily map `f` (with extra `args`/`kwargs`) over `items`: in this thread if serial, otherwise in the pool"
        from threading import Lock
        # Workers are threads of this process, so a plain `threading.Lock` is enough (no manager process),
        # and `_call` only uses it when pausing
        self.lock = Lock() if (self.pause and not self.not_parallel) else None
        g = partial(f, *args, **kwargs)
        if self.not_parallel: return map(g, items)
        _g = partial(_call, self.lock, self.pause, self.max_workers, g)
        # NOTE(review): only errors raised while submitting are caught here (`None` is then returned);
        # errors inside workers surface when the returned iterator is consumed
        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
        except Exception as e: self.on_exc(e)

In [None]:
show_doc(ProcessPoolExecutor, title_level=4)

<h4 id="ProcessPoolExecutor" class="doc_header"><code>class</code> <code>ProcessPoolExecutor</code><a href="" class="source_link" style="float:right">[source]</a></h4>

> <code>ProcessPoolExecutor</code>(**`max_workers`**=*`64`*, **`on_exc`**=*`print`*, **`pause`**=*`0`*, **\*\*`kwargs`**) :: [`ProcessPoolExecutor`](/xtras.html#ProcessPoolExecutor)

Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution

In [None]:
#export
# `fastprogress` is optional: without it `progress_bar` is None and progress display is skipped
try: from fastprogress import progress_bar
except ImportError: progress_bar = None

In [None]:
#export 
def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None, pause=0,
             threadpool=False, timeout=None, chunksize=1, **kwargs):
    "Applies `f` in parallel to `items`, using `n_workers` (threads if `threadpool`, else processes), and returns the results as an `L`"
    # By default a progress bar is shown whenever `fastprogress` could be imported
    show_progress = (progress_bar is not None) if progress is None else progress
    executor_cls = ThreadPoolExecutor if threadpool else ProcessPoolExecutor
    with executor_cls(n_workers, pause=pause) as ex:
        results = ex.map(f,items, *args, timeout=timeout, chunksize=chunksize, **kwargs)
        if show_progress:
            n_items = len(items) if total is None else total
            results = progress_bar(results, total=n_items, leave=False)
        # Consume the lazy results before the executor is shut down
        return L(results)

In [None]:
def add_one(x, a=1): 
    time.sleep(random.random()/80)
    return x+a

inp,exp = range(50),range(1,51)
test_eq(parallel(add_one, inp, n_workers=2, progress=False), exp)
test_eq(parallel(add_one, inp, threadpool=True, n_workers=2, progress=False), exp)
test_eq(parallel(add_one, inp, n_workers=0), exp)
test_eq(parallel(add_one, inp, n_workers=1, a=2), range(2,52))
test_eq(parallel(add_one, inp, n_workers=0, a=2), range(2,52))

Use the `pause` parameter to ensure a pause of `pause` seconds between processes starting. This is in case there are race conditions in starting some process, or to stagger the time each process starts, for example when making many requests to a webserver. Set `threadpool=True` to use `ThreadPoolExecutor` instead of `ProcessPoolExecutor`.

In [None]:
from datetime import datetime

In [None]:
def print_time(i): 
    time.sleep(random.random()/1000)
    print(i, datetime.now())

parallel(print_time, range(5), n_workers=2, pause=0.25);

0 2020-10-31 13:13:41.429486
1 2020-10-31 13:13:41.679394
2 2020-10-31 13:13:41.930404
3 2020-10-31 13:13:42.181027
4 2020-10-31 13:13:42.431391


Note that `f` should accept a collection of items.

In [None]:
#export
def run_procs(f, f_done, args):
    "Start one `Process` running `f` per item of `args`, yield everything from `f_done()`, then join the processes"
    workers = L(args).map(Process, args=arg0, target=f)
    for w in workers: w.start()
    yield from f_done()
    for w in workers: w.join()

In [None]:
#export
def _f_pg(obj, queue, batch, start_idx):
    for i,b in enumerate(obj(batch)): queue.put((start_idx+i,b))

def _done_pg(queue, items): return (queue.get() for _ in items)

In [None]:
#export 
def parallel_gen(cls, items, n_workers=defaults.cpus, **kwargs):
    "Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel, yielding `(index, result)` tuples"
    if n_workers==0:
        # Serial mode: run everything in this process and number the results in order
        serial_results = list(cls(**kwargs)(items))
        yield from enumerate(serial_results)
        return
    chunks = L(chunked(items, n_chunks=n_workers))
    # Running totals of the chunk sizes give the index at which each chunk starts
    # (presumably `0 + ...` prepends the initial 0 - verify against `L`)
    starts = L(itertools.accumulate(0 + chunks.map(len)))
    results_q = Queue()
    if progress_bar: items = progress_bar(items, leave=False)
    worker = partial(_f_pg, cls(**kwargs), results_q)
    collect = partial(_done_pg, results_q, items)
    yield from run_procs(worker, collect, L(chunks,starts).zip())

In [None]:
class _C:
    def __call__(self, o): return ((i+1) for i in o)

items = range(5)

res = L(parallel_gen(_C, items, n_workers=3))
idxs,dat1 = zip(*res.sorted(itemgetter(0)))
test_eq(dat1, range(1,6))

res = L(parallel_gen(_C, items, n_workers=0))
idxs,dat2 = zip(*res.sorted(itemgetter(0)))
test_eq(dat2, dat1)

`cls` is any class with `__call__`. It will be passed `kwargs` when initialized. Note that `n_workers` instances of `cls` are created, one in each process. `items` are then split in `n_workers` batches and one is sent to each `cls`. The function then returns a generator of tuples of item indices and results.

In [None]:
class TestSleepyBatchFunc:
    "For testing parallel processes that run at different speeds"
    def __init__(self): self.a=1
    def __call__(self, batch):
        for k in batch:
            time.sleep(random.random()/4)
            yield k+self.a

x = np.linspace(0,0.99,20)
res = L(parallel_gen(TestSleepyBatchFunc, x, n_workers=2))
test_eq(res.sorted().itemgot(1), x+1)

In [None]:
#export
def threaded(f):
    "Decorator: calling the result starts `f` in a new `Thread` with the given arguments and returns that thread"
    @wraps(f)
    def _launch(*args, **kwargs):
        worker = Thread(target=f, args=args, kwargs=kwargs)
        worker.start()
        return worker
    return _launch

# Export -

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_test.ipynb.
Converted 01_basics.ipynb.
Converted 02_foundation.ipynb.
Converted 03_xtras.ipynb.
Converted 04_dispatch.ipynb.
Converted 05_transform.ipynb.
Converted 06_logargs.ipynb.
Converted 07_meta.ipynb.
Converted 08_script.ipynb.
Converted index.ipynb.
