Skip to content
Switch branches/tags
Go to file
Latest commit ab15492 Apr 8, 2021 History
* set vmin to 0

* remove check at the end

* Trigger Build

Co-authored-by: Thomas Capelle <>
7 contributors

Users who have contributed to this file

@jph00 @tcapelle @marii-moe @mszhanyi @muellerzr @antorsae @albertvillanova
802 lines (682 sloc) 28.3 KB
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_torch_core.ipynb (unless otherwise specified).
# Names exported by `from <this module> import *` (list generated together with the rest of this file).
__all__ = ['progress_bar', 'master_bar', 'subplots', 'show_image', 'show_titled_image', 'show_images', 'ArrayBase',
'ArrayImageBase', 'ArrayImage', 'ArrayImageBW', 'ArrayMask', 'tensor', 'set_seed', 'get_random_states',
'set_random_states', 'no_random', 'unsqueeze', 'unsqueeze_', 'apply', 'maybe_gather', 'to_detach', 'to_half',
'to_float', 'default_device', 'to_device', 'to_cpu', 'to_np', 'to_concat', 'TensorBase', 'TensorImageBase',
'TensorImage', 'TensorImageBW', 'TensorMask', 'TensorFlowField', 'TensorCategory', 'TensorMultiCategory',
'TitledTensorScalar', 'concat', 'Chunks', 'show_title', 'ShowTitle', 'TitledInt', 'TitledFloat', 'TitledStr',
'TitledTuple', 'get_empty_df', 'display_df', 'get_first', 'one_param', 'item_find', 'find_device', 'find_bs',
'np_func', 'Module', 'get_model', 'one_hot', 'one_hot_decode', 'params', 'trainable_params', 'norm_types',
'norm_bias_params', 'batch_to_samples', 'logit', 'num_distrib', 'rank_distrib', 'distrib_barrier',
'base_doc', 'doc', 'nested_reorder', 'make_cross_image', 'show_image_batch', 'requires_grad', 'init_default',
'cond_init', 'apply_leaf', 'apply_init', 'script_use_ctx', 'script_save_ctx', 'script_fwd', 'script_bwd',
'grad_module', 'flatten_check']
# Cell
from .imports import *
from .torch_imports import *
# Cell
#nbdev_comment _all_ = ['progress_bar','master_bar']
# Cell
# Import-time CUDA setup: optionally select the GPU named by $DEFAULT_GPU and enable cudnn autotuning.
if torch.cuda.is_available():
    # Only choose a device while the current one is still device 0
    if torch.cuda.current_device()==0:
        # An unset or empty DEFAULT_GPU falls back to device 0
        def_gpu = int(os.environ.get('DEFAULT_GPU') or 0)
        # Valid ordinals are 0..device_count()-1, so the comparison has to be strict
        if torch.cuda.device_count()>def_gpu: torch.cuda.set_device(def_gpu)
    torch.backends.cudnn.benchmark = True
# Cell
@delegates(plt.subplots, keep=True)
def subplots(nrows=1, ncols=1, figsize=None, imsize=3,suptitle=None, **kwargs):
    "Call `plt.subplots`, deriving `figsize` from `imsize` when not given; a single `Axes` is wrapped in an array"
    if figsize is None:
        h = nrows*imsize
        # Small figures with a suptitle get 0.6 of extra height (presumably so the title does not overlap the axes)
        if suptitle is not None and not imsize>2: h = nrows*imsize+0.6
        figsize = (ncols*imsize, h)
    fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
    if suptitle is not None: fig.suptitle(suptitle)
    # Callers can always iterate/index the result, even for a 1x1 grid
    if nrows*ncols==1: ax = array([ax])
    return fig,ax
# Cell
def _fig_bounds(x):
r = x//32
return min(5, max(1,r))
# Cell
@delegates(plt.Axes.imshow, keep=True, but=['shape', 'imlim'])
def show_image(im, ax=None, figsize=None, title=None, ctx=None, **kwargs):
    "Show a PIL or PyTorch image on `ax` (or `ctx`), creating a new axis when neither is given; returns the axis."
    # Handle pytorch axis order
    if hasattrs(im, ('data','cpu','permute')):
        im = im.data.cpu()
        # Fewer than 5 entries in dim 0 is taken to mean channels-first: move channels last for imshow
        if im.shape[0]<5: im=im.permute(1,2,0)
    elif not isinstance(im,np.ndarray): im=array(im)
    # Handle 1-channel images
    if im.shape[-1]==1: im=im[...,0]
    ax = ifnone(ax,ctx)
    if figsize is None: figsize = (_fig_bounds(im.shape[0]), _fig_bounds(im.shape[1]))
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    return ax
# Cell
@delegates(show_image, keep=True)
def show_titled_image(o, **kwargs):
    "Call `show_image` destructuring `o` to `(img,title)`; the title is converted with `str`"
    img,title = o[0],o[1]
    show_image(img, title=str(title), **kwargs)
# Cell
def show_images(ims, nrows=1, ncols=None, titles=None, **kwargs):
    "Show all images `ims` as subplots with `nrows` rows using `titles`; `kwargs` are passed to `subplots`."
    # Default to as many columns as needed to fit every image in `nrows` rows
    if ncols is None: ncols = int(math.ceil(len(ims)/nrows))
    if titles is None: titles = [None]*len(ims)
    axs = subplots(nrows, ncols, **kwargs)[1].flat
    for im,t,ax in zip(ims, titles, axs): show_image(im, ax=ax, title=t)
# Cell
class ArrayBase(ndarray):
    "An `ndarray` that can modify casting behavior"
    @classmethod
    def _before_cast(cls, x):
        "Return `x` unchanged if it is already an `ndarray`, otherwise convert it with `array`"
        return x if isinstance(x,ndarray) else array(x)
# Cell
class ArrayImageBase(ArrayBase):
    "Base class for arrays representing images"
    # Default keyword arguments for `show_image`; overridden by subclasses
    _show_args = {'cmap':'viridis'}
    def show(self, ctx=None, **kwargs):
        "Display `self` with `show_image`; `kwargs` take precedence over `_show_args`"
        return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})
# Cell
class ArrayImage(ArrayImageBase):
    "An array representing an image"
# Cell
class ArrayImageBW(ArrayImage):
    "An array representing an image, shown with the 'Greys' colormap"
    _show_args = {'cmap':'Greys'}
# Cell
class ArrayMask(ArrayImageBase):
    "An array representing an image mask"
    # Half-transparent, categorical colormap, no interpolation between values
    _show_args = {'alpha':0.5, 'cmap':'tab20', 'interpolation':'nearest'}
# Cell
@patch
def __array_eq__(self:Tensor,b):
    "Whole-tensor equality via `torch.equal`; a 0-dim tensor falls back to `==`"
    return torch.equal(self,b) if self.dim() else self==b
# Cell
def _array2tensor(x):
if x.dtype==np.uint16: x = x.astype(np.float32)
# windows default numpy int dytpe is int32, while torch tensor default int dtype is int64
if sys.platform == "win32":
if x = x.astype(np.int64)
return torch.from_numpy(x)
# Cell
@use_kwargs_dict(dtype=None, device=None, requires_grad=False, pin_memory=False)
def tensor(x, *rest, **kwargs):
    "Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly."
    if len(rest): x = (x,)+rest
    # There was a Pytorch bug in dataloader using num_workers>0. Haven't confirmed if fixed
    # if isinstance(x, (tuple,list)) and len(x)==0: return tensor(0)
    if isinstance(x, Tensor): res = x
    elif isinstance(x, (tuple,list)): res = torch.tensor(x, **kwargs)
    elif isinstance(x, ndarray): res = _array2tensor(x)
    elif isinstance(x, (pd.Series, pd.DataFrame)): res = as_tensor(x.values, **kwargs)
    elif hasattr(x, '__array__') or is_iter(x): res = as_tensor(x, **kwargs)
    else:
        # `_array2tensor` accepts no keyword arguments, so any `kwargs` are applied afterwards
        res = _array2tensor(array(x))
        if kwargs: res = as_tensor(res, **kwargs)
    # float64 results are always converted to float32
    if res.dtype is torch.float64: return res.float()
    return res
# Cell
def set_seed(s, reproducible=False):
    "Set random seed for `random`, `torch`, and `numpy` (where available)"
    # Each library is optional: a NameError means it was never imported
    try: torch.manual_seed(s)
    except NameError: pass
    try: torch.cuda.manual_seed_all(s)
    except NameError: pass
    try: np.random.seed(s%(2**32-1))
    except NameError: pass
    random.seed(s)
    if reproducible:
        # Trade cudnn autotuning for deterministic kernels
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# Cell
def get_random_states():
    "Gets states for `random`, `torch`, and `numpy` random number generators"
    # Keys match the parameter names of `set_random_states`, so the dict can be passed back as `**states`
    return {'random_state':random.getstate(),
            'numpy_state':np.random.get_state(),
            'torch_state':torch.get_rng_state(),
            'torch_cuda_state':torch.cuda.get_rng_state_all(),
            'torch_deterministic':torch.backends.cudnn.deterministic,
            'torch_benchmark':torch.backends.cudnn.benchmark}
# Cell
def set_random_states(random_state,numpy_state,torch_state,torch_cuda_state,torch_deterministic,torch_benchmark):
    "Set states for `random`, `torch`, and `numpy` random number generators"
    random.setstate(random_state)
    np.random.set_state(numpy_state)
    torch.set_rng_state(torch_state)
    torch.cuda.set_rng_state_all(torch_cuda_state)
    # The cudnn flags are restored along with the generator states
    torch.backends.cudnn.deterministic = torch_deterministic
    torch.backends.cudnn.benchmark = torch_benchmark
# Cell
@contextmanager
def no_random(seed=42,reproducible=True):
    "Stores and retrieves state of random number generators. Sets random seed for `random`, `torch`, and `numpy`."
    states = get_random_states()
    set_seed(seed,reproducible=reproducible)
    try:
        yield #we are managing global variables
    finally:
        # Restore the previous generator states even if the body raised
        set_random_states(**states)
# Cell
def unsqueeze(x, dim=-1, n=1):
    "Same as `torch.unsqueeze` but can add `n` dims"
    for _ in range(n): x = x.unsqueeze(dim)
    return x
# Cell
def unsqueeze_(x, dim=-1, n=1):
    "Same as `torch.unsqueeze_` but can add `n` dims; modifies `x` in place and returns it"
    for _ in range(n): x.unsqueeze_(dim)
    return x
# Cell
def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2(*args, **kwargs))
def _fa_rebuild_qtensor(cls, *args, **kwargs): return cls(torch._utils._rebuild_qtensor (*args, **kwargs))
# Cell
def apply(func, x, *args, **kwargs):
    "Apply `func` recursively to `x`, passing on args"
    # Containers are rebuilt with the same type (list-likes) or as a plain dict
    if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])
    if isinstance(x,dict): return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}
    res = func(x, *args, **kwargs)
    # Leaves keep the type of the input, except for `None`
    return res if x is None else retain_type(res, x)
# Cell
def maybe_gather(x, axis=0):
    "Gather copies of `x` on `axis` (if training is distributed)"
    if num_distrib()<=1: return x
    ndim = x.ndim
    # 0-dim tensors are gathered as shape-(1,) tensors
    res = [x.new_zeros(*x.shape if ndim > 0 else (1,)) for _ in range(num_distrib())]
    torch.distributed.all_gather(res, x.contiguous() if ndim > 0 else x[None])
    # Scalars are averaged across processes; everything else is concatenated
    return torch.cat(res, dim=axis) if ndim > 0 else torch.cat(res, dim=axis).mean()
# Cell
def to_detach(b, cpu=True, gather=True):
    "Recursively detach lists of tensors in `b`; gather them if `gather=True` and put them on the CPU if `cpu=True`."
    def _inner(x, cpu=True, gather=True):
        # Non-tensor leaves pass through untouched
        if not isinstance(x,Tensor): return x
        x = x.detach()
        if gather: x = maybe_gather(x)
        return x.cpu() if cpu else x
    return apply(_inner, b, cpu=cpu, gather=gather)
# Cell
def to_half(b):
    "Recursively map floating-point tensors in `b` to FP16; other tensors are left unchanged."
    return apply(lambda x: x.half() if torch.is_floating_point(x) else x, b)
# Cell
def to_float(b):
    "Recursively map floating-point tensors in `b` to FP32; other tensors (e.g. int) are left unchanged."
    return apply(lambda x: x.float() if torch.is_floating_point(x) else x, b)
# Cell
# Tri-state switch consulted by `default_device` and `to_device`:
# None: True if available; True: error if not available; False: use CPU
defaults.use_cuda = None
# Cell
def default_device(use_cuda=-1):
    "Return or set default device; `use_cuda`: None - CUDA if available; True - error if not available; False - CPU"
    # -1 is the "leave the stored setting alone" sentinel, since None is itself a valid setting
    if use_cuda != -1: defaults.use_cuda=use_cuda
    use = defaults.use_cuda or (torch.cuda.is_available() and defaults.use_cuda is None)
    # CUDA was explicitly requested but is not available
    assert torch.cuda.is_available() or not use
    return torch.device(torch.cuda.current_device()) if use else torch.device('cpu')
# Cell
def to_device(b, device=None):
    "Recursively put `b` on `device` (the default device when `None`)."
    # `defaults.use_cuda==False` forces the CPU even when a `device` was passed
    if defaults.use_cuda==False: device='cpu'
    elif device is None: device=default_device()
    def _inner(o):
        if isinstance(o,Tensor): return o.to(device, non_blocking=True)
        # Objects may provide their own `to_device`; anything else is returned unchanged
        return o.to_device(device) if hasattr(o, "to_device") else o
    return apply(_inner, b)
# Cell
def to_cpu(b):
    "Recursively map lists of tensors in `b` to the cpu."
    return to_device(b,'cpu')
# Cell
def to_np(x):
    "Convert a tensor (or, recursively, a collection of tensors) to numpy arrays."
    return apply(lambda o: o.data.cpu().numpy(), x)
# Cell
def to_concat(xs, dim=0):
    "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
    if not xs: return xs
    # Tuples/lists and dicts are concatenated position by position / key by key
    if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
    if isinstance(xs[0],dict): return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
    #We may receive xs that are not concatenable (inputs of a text classifier for instance),
    #   in this case we return a big list
    try: return retain_type(torch.cat(xs, dim=dim), xs[0])
    except Exception:
        return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
                      for i in range_of(o_)) for o_ in xs], L())
# Cell
@patch
def set_meta(self:Tensor, x, as_copy=False):
    "Set all metadata in `__dict__` from `x`; shares `x.__dict__` unless `as_copy`, which takes a shallow copy"
    if not hasattr(x,'__dict__'): return
    # XXX: change to `deepcopy` once PyTorch 1.7.1 is out, and check nb 23 segmentation fit works
    self.__dict__ = copy(x.__dict__) if as_copy else x.__dict__
# Cell
# Provide `torch.as_subclass` when this torch build lacks it
if not hasattr(torch,'as_subclass'): torch.as_subclass = torch.Tensor.as_subclass
# Cell
@patch
def as_subclass(self:Tensor, typ):
    "Cast to `typ` and include `__dict__` and meta"
    return retain_meta(self, torch.as_subclass(self, typ))
# Cell
def _torch_handled(args, opt, func):
if func not in opt: return False
for oks in opt[func]:
if all(isinstance(arg,ok) for arg,ok in zip(args,oks) if ok): return True
# Cell
class TensorBase(Tensor):
    "A `Tensor` which support subclass pickling, and maintains metadata when casting or after methods"
    # `debug` prints every intercepted torch function; `_opt` maps a function to the argument-type tuples registered for it
    debug,_opt = False,defaultdict(list)
    def __new__(cls, x, **kwargs):
        res = cast(tensor(x), cls)
        # Extra keyword arguments become attributes (metadata) of the new tensor
        for k,v in kwargs.items(): setattr(res, k, v)
        return res

    @classmethod
    def _before_cast(cls, x): return tensor(x)
    def __repr__(self): return re.sub('tensor', self.__class__.__name__, super().__repr__())

    def __reduce_ex__(self,proto):
        "Pickle through the `_fa_rebuild_*` helpers so the subclass is restored on load"
        args = (type(self), self.storage(), self.storage_offset(), tuple(self.size()), self.stride())
        # NOTE(review): confirm this argument layout matches `torch._utils._rebuild_qtensor` in the supported torch versions
        if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point())
        f = _fa_rebuild_qtensor if self.is_quantized else _fa_rebuild_tensor
        return (f, args + (self.requires_grad, OrderedDict()))

    @classmethod
    def register_func(cls, func, *oks):
        "Register the argument types `oks` for `func`; matching calls have their result converted to the caller's type"
        cls._opt[func].append(oks)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if self.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
        convert=False
        # For registered argument combinations: run as plain tensors, then convert the result back to our type
        if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
        res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
        if convert: res = convert(res)
        # Carry our metadata over to the result
        if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)
        return res

    def new_tensor(self, size, dtype=None, device=None, requires_grad=False):
        cls = type(self)
        return self.as_subclass(Tensor).new_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)

    def new_ones(self, data, dtype=None, device=None, requires_grad=False):
        cls = type(self)
        return self.as_subclass(Tensor).new_ones(data, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)

    def new(self, x=None):
        cls = type(self)
        res = self.as_subclass(Tensor).new() if x is None else self.as_subclass(Tensor).new(x)
        return res.as_subclass(cls)

    def requires_grad_(self, requires_grad=True):
        # Workaround: set the attribute directly instead of calling the inherited method
        self.requires_grad = requires_grad
        return self
# Cell
class TensorImageBase(TensorBase):
    "Base class for tensors representing images"
    _show_args = ArrayImageBase._show_args
    def show(self, ctx=None, **kwargs):
        "Display `self` with `show_image`; `kwargs` take precedence over `_show_args`"
        return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})
# Cell
class TensorImage(TensorImageBase):
    "A tensor representing an image"
# Cell
class TensorImageBW(TensorImage):
    "A `TensorImage` shown with the display arguments of `ArrayImageBW`"
    _show_args = ArrayImageBW._show_args
# Cell
class TensorMask(TensorImageBase):
    "A tensor representing an image mask"
    _show_args = ArrayMask._show_args
    def show(self, ctx=None, **kwargs):
        "Show the mask; when `codes` metadata is present, the color range defaults to `0..len(codes)`"
        codes = getattr(self, 'codes', None)
        # Explicit `vmin`/`vmax` in `kwargs` still win over these defaults
        if codes is not None: kwargs = merge({'vmin': 0, 'vmax': len(codes)}, kwargs)
        return super().show(ctx=ctx, **kwargs)
# Cell
# Register mixed mask/image argument pairs (in either order) for these operators
for o in Tensor.__ne__,Tensor.__eq__,Tensor.add,Tensor.sub,Tensor.mul,Tensor.div,Tensor.__rsub__,Tensor.__radd__,Tensor.matmul,Tensor.bmm:
    TensorBase.register_func(o, TensorMask, TensorImageBase)
    TensorBase.register_func(o, TensorImageBase, TensorMask)

# `torch.einsum` takes the equation string first, then the operands
TensorMask.register_func(torch.einsum, str, TensorImageBase, TensorMask)
TensorMask.register_func(torch.einsum, str, TensorMask, TensorImageBase)
# Cell
class TensorFlowField(TensorBase):
    "A tensor registered as the second argument type of `F.grid_sample` for image inputs"

TensorImage.register_func(F.grid_sample, TensorImageBase, TensorFlowField)
# Cell
class TensorCategory(TensorBase):
    "A `TensorBase` subclass used to mark category tensors"
# Cell
class TensorMultiCategory(TensorCategory):
    "A `TensorCategory` subclass used to mark multi-category tensors"
# Cell
class TitledTensorScalar(TensorBase):
    "A tensor containing a scalar that has a `show` method"
    def show(self, **kwargs):
        "Show the scalar value (via `item()`) with `show_title`"
        show_title(self.item(), **kwargs)
# Cell
@patch
def tensored(self:L):
    "`mapped(tensor)`: convert every item of the list with `tensor`"
    return self.map(tensor)
@patch
def stack(self:L, dim=0):
    "Same as `torch.stack`"
    return torch.stack(list(self.tensored()), dim=dim)
@patch
def cat(self:L, dim=0):
    "Same as `torch.cat`"
    return torch.cat(list(self.tensored()), dim=dim)
# Cell
def concat(*ls):
    "Concatenate tensors, arrays, lists, or tuples"
    if not len(ls): return []
    # The first argument decides both how to concatenate and the type of the result
    it = ls[0]
    if isinstance(it,torch.Tensor): res = torch.cat(ls)
    elif isinstance(it,ndarray): res = np.concatenate(ls)
    else:
        res = itertools.chain.from_iterable(map(L,ls))
        if isinstance(it,(tuple,list)): res = type(it)(res)
        else: res = L(res)
    return retain_type(res, it)
# Cell
class Chunks:
    "Slice and int indexing into a list of lists"
    def __init__(self, chunks, lens=None):
        self.chunks = chunks
        self.lens = L(map(len,self.chunks) if lens is None else lens)
        # Cumulative lengths; `0+self.lens` presumably prepends a leading 0 — TODO confirm against `L.__radd__`
        self.cumlens = np.cumsum(0+self.lens)
        self.totlen = self.cumlens[-1]

    def __getitem__(self,i):
        "Item or slice at flat position `i`, cast to the type of the first chunk"
        if isinstance(i,slice): return retain_type(self.getslice(i), old=self.chunks[0])
        di,idx = self.doc_idx(i)
        return retain_type(self.chunks[di][idx], old=self.chunks[0])

    def getslice(self, i):
        "Concatenate the pieces of every chunk overlapped by slice `i` (`i.step` is not used)"
        st_d,st_i = self.doc_idx(ifnone(i.start,0))
        en_d,en_i = self.doc_idx(ifnone(i.stop,self.totlen+1))
        # First chunk: up to `en_i` if the slice ends in the same chunk, otherwise to the chunk's end
        res = [self.chunks[st_d][st_i:(en_i if st_d==en_d else sys.maxsize)]]
        # Chunks strictly between first and last are taken whole
        for b in range(st_d+1,en_d): res.append(self.chunks[b])
        if st_d!=en_d and en_d<len(self.chunks): res.append(self.chunks[en_d][:en_i])
        return concat(*res)

    def doc_idx(self, i):
        "Translate flat index `i` to `(chunk index, offset inside that chunk)`"
        if i<0: i=self.totlen+i # count from end
        docidx = np.searchsorted(self.cumlens, i+1)-1
        cl = self.cumlens[docidx]
        return docidx,i-cl
# Cell
def show_title(o, ax=None, ctx=None, label=None, color='black', **kwargs):
    "Set title of `ax` to `o`, or print `o` if `ax` is `None`; returns `ax` (a new `pd.Series` when `ax` is a Series)"
    # `ctx` is used when no `ax` is given
    if ax is None: ax = ctx
    if ax is None: print(o)
    elif hasattr(ax, 'set_title'):
        # Append below an existing title instead of replacing it
        t = ax.title.get_text()
        if len(t) > 0: o = t+'\n'+str(o)
        ax.set_title(o, color=color)
    elif isinstance(ax, pd.Series):
        # Make `label` unique within the index by appending underscores
        while label in ax: label += '_'
        # `Series.append` no longer exists in pandas 2; `pd.concat` works in every version
        ax = pd.concat([ax, pd.Series({label: o})])
    return ax
# Cell
class ShowTitle:
    "Base class that adds a simple `show`"
    # Default keyword arguments for `show_title`
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledInt(Int, ShowTitle):
    # `show` is defined here (not just inherited from `ShowTitle`) so it is found first in the MRO
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledFloat(Float, ShowTitle):
    # `show` is defined here (not just inherited from `ShowTitle`) so it is found first in the MRO
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledStr(Str, ShowTitle):
    # `show` is defined here (not just inherited from `ShowTitle`) so it is found first in the MRO
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledTuple(fastuple, ShowTitle):
    # `show` is defined here (not just inherited from `ShowTitle`) so it is found first in the MRO
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
# Class docstrings for the `Titled*` types
add_docs(TitledInt, "An `int` with `show`")
add_docs(TitledStr, "An `str` with `show`")
add_docs(TitledFloat, "A `float` with `show`")
add_docs(TitledTuple, "A `fastuple` with `show`")
# Cell
@patch
def truncate(self:TitledStr, n):
    "Truncate self to the first `n` space-separated words"
    words = self.split(' ')[:n]
    return TitledStr(' '.join(words))
# Cell
# Keep a handle on the original constructor exactly once, so re-running this cell does not wrap the wrapper
if not hasattr(pd.DataFrame,'_old_init'): pd.DataFrame._old_init = pd.DataFrame.__init__
# Cell
@patch
def __init__(self:pd.DataFrame, data=None, index=None, columns=None, dtype=None, copy=False):
    "`DataFrame` constructor that also accepts a `Tensor` as `data` (converted with `to_np`)"
    if data is not None and isinstance(data, Tensor): data = to_np(data)
    self._old_init(data, index=index, columns=columns, dtype=dtype, copy=copy)
# Cell
def get_empty_df(n):
    "Return `n` empty rows of a dataframe"
    df = pd.DataFrame(index = range(n))
    return [df.iloc[i] for i in range(n)]
# Cell
def display_df(df):
    "Display `df` in a notebook or defaults to print"
    # IPython is optional; without it, fall back to plain printing
    try: from IPython.display import display, HTML
    except Exception: return print(df)
    display(HTML(df.to_html()))
# Cell
def get_first(c):
    "Get the first element of c, even if c is a dataframe"
    # Objects exposing `iloc` are indexed positionally through it
    return getattr(c, 'iloc', c)[0]
# Cell
def one_param(m):
    "First parameter in `m`"
    return first(m.parameters())
# Cell
def item_find(x, idx=0):
    "Take the `idx`-th element of `x`, then keep descending into the first element until a leaf is reached"
    # Note: `idx` applies to the top level only; the recursive calls use the default `idx=0`
    if is_listy(x): return item_find(x[idx])
    if isinstance(x,dict):
        # An int `idx` selects the key by position, anything else is used as the key itself
        key = list(x.keys())[idx] if isinstance(idx, int) else idx
        return item_find(x[key])
    return x
# Cell
def find_device(b):
    "Recursively search the device of `b`."
    return item_find(b).device
# Cell
def find_bs(b):
    "Recursively search the batch size of `b`."
    return item_find(b).shape[0]
# Cell
def np_func(f):
    "Convert a function taking and returning numpy arrays to one taking and returning tensors"
    def _inner(*args, **kwargs):
        # Only positional tensor arguments are converted; keyword arguments pass through as-is
        nargs = [to_np(arg) if isinstance(arg,Tensor) else arg for arg in args]
        return tensor(f(*nargs, **kwargs))
    functools.update_wrapper(_inner, f)
    return _inner
# Cell
class Module(nn.Module, metaclass=PrePostInitMeta):
    "Same as `nn.Module`, but no need for subclasses to call `super().__init__`"
    # `nn.Module.__init__` is invoked from the pre-init hook instead
    def __pre_init__(self, *args, **kwargs): super().__init__()
    def __init__(self): pass
# Cell
from torch.nn.parallel import DistributedDataParallel
# Cell
def get_model(model):
    "Return the model maybe wrapped inside `model`."
    # Both parallel wrappers keep the wrapped model in `.module`
    return model.module if isinstance(model, (DistributedDataParallel, nn.DataParallel)) else model
# Cell
def one_hot(x, c):
    "One-hot encode `x` with `c` classes; returns a `uint8` tensor of length `c` with ones at the indices in `x`."
    res = torch.zeros(c, dtype=torch.uint8)
    if isinstance(x, Tensor) and x.numel()>0: res[x] = 1.
    # Anything else (including an empty tensor) is indexed as a plain list
    else: res[list(L(x, use_list=None))] = 1.
    return res
# Cell
def one_hot_decode(x, vocab=None):
    "Positions of `x` equal to 1, mapped through `vocab` when one is given"
    return L(vocab[i] if vocab else i for i,x_ in enumerate(x) if x_==1)
# Cell
def params(m):
    "Return all parameters of `m`"
    return list(m.parameters())
# Cell
def trainable_params(m):
    "Return all trainable parameters of `m`"
    return [p for p in m.parameters() if p.requires_grad]
# Cell
# Normalization layer classes, for use with `isinstance`
norm_types = (
    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
    nn.LayerNorm,
)
# Cell
def norm_bias_params(m, with_bias=True):
    "Return all parameters of normalization layers in `m` and, if `with_bias`, the biases of the other layers"
    # A normalization layer contributes every one of its parameters
    if isinstance(m, norm_types): return L(m.parameters())
    res = L(m.children()).map(norm_bias_params, with_bias=with_bias).concat()
    if with_bias and getattr(m, 'bias', None) is not None: res.append(m.bias)
    return res
# Cell
def batch_to_samples(b, max_n=10):
    "'Transposes' a batch to (at most `max_n`) samples"
    if isinstance(b, Tensor): return retain_types(list(b[:max_n]), [b])
    # Convert every element of the batch, then zip them so each item is one sample
    res = L(b).map(partial(batch_to_samples,max_n=max_n))
    return retain_types(res.zip(), [b])
# Cell
@patch
def interp_1d(x:Tensor, xp, fp):
    "Same as `np.interp`"
    # Slope and intercept of each segment between consecutive points
    slopes = (fp[1:]-fp[:-1])/(xp[1:]-xp[:-1])
    incx = fp[:-1] - (slopes*xp[:-1])
    # Segment index for each `x`, clamped to the first/last segment
    locs = (x[:,None]>=xp[None,:]).long().sum(1)-1
    locs = locs.clamp(0,len(slopes)-1)
    return slopes[locs]*x + incx[locs]
# Cell
@patch
def pca(x:Tensor, k=2):
    "Compute PCA of `x` with `k` dimensions."
    # Center the columns, then project onto the first `k` left singular vectors of `x.t()`
    x = x-torch.mean(x,0)
    U,S,V = torch.svd(x.t())
    return torch.mm(x,U[:,:k])
# Cell
def logit(x):
    "Logit of `x`, clamped to avoid inf."
    x = x.clamp(1e-7, 1-1e-7)
    return -(1/x-1).log()
# Cell
def num_distrib():
    "Return the number of processes in distributed training (if applicable); 0 when `WORLD_SIZE` is unset."
    return int(os.environ.get('WORLD_SIZE', 0))
# Cell
def rank_distrib():
    "Return the distributed rank of this process (if applicable); 0 when `RANK` is unset."
    return int(os.environ.get('RANK', 0))
# Cell
def distrib_barrier():
    "Place a synchronization barrier in distributed training"
    # No-op unless several processes are running and the process group has been initialized
    if num_distrib() > 1 and torch.distributed.is_initialized(): torch.distributed.barrier()
# Cell
# Saving arrays requires pytables - optional dependency
# Best-effort import: `tables` is only used by the array save/load helpers in this file
try: import tables
except: pass
# Cell
def _comp_filter(lib='lz4',lvl=3):
    "Build a `tables.Filters` using the blosc compressor `lib` at compression level `lvl`"
    complib = f'blosc:{lib}'
    return tables.Filters(complib=complib, complevel=lvl)
# Cell
@patch
def save_array(p:Path, o, complib='lz4', lvl=3):
    "Save numpy array to a compressed `pytables` file, using compression level `lvl`"
    if isinstance(o,Tensor): o = to_np(o)
    # The array is stored as a carray named 'data' at the root of the file
    with tables.open_file(p, mode='w', filters=_comp_filter(lib=complib,lvl=lvl)) as f: f.create_carray('/', 'data', obj=o)
# Cell
@patch
def load_array(p:Path):
    "Load numpy array from a `pytables` file"
    # Reads the 'data' node at the root of the file
    with tables.open_file(p, 'r') as f: return f.root.data.read()
# Cell
def base_doc(elt):
    "Print a base documentation of `elt`: its name with signature, then its docstring"
    import inspect
    name = getattr(elt, '__qualname__', getattr(elt, '__name__', ''))
    print(f'{name}{inspect.signature(elt)}\n{inspect.getdoc(elt)}\n')
    print('To get a prettier result with hyperlinks to source code and documentation, install nbdev: pip install nbdev')
# Cell
def doc(elt):
    "Try to use `doc` from nbdev and fall back to `base_doc`"
    try:
        # This import shadows the outer `doc` inside the function only
        from nbdev.showdoc import doc
        doc(elt)
    except Exception: base_doc(elt)
# Cell
def nested_reorder(t, idxs):
    "Reorder all tensors in `t` using `idxs`"
    if isinstance(t, (Tensor,L)): return t[idxs]
    # Other list-likes are rebuilt with the same type, reordering each element
    elif is_listy(t): return type(t)(nested_reorder(t_, idxs) for t_ in t)
    if t is None: return t
    raise TypeError(f"Expected tensor, tuple, list or L but got {type(t)}")
# Cell
def make_cross_image(bw=True):
    "Create a tensor containing a cross image, either `bw` (True) or color"
    if bw:
        # 5x5 image: middle row and middle column set to 1
        im = torch.zeros(5,5)
        im[2,:] = 1.
        im[:,2] = 1.
    else:
        # 3x5x5 image: middle row in channel 0, middle column in channel 1
        im = torch.zeros(3,5,5)
        im[0,2,:] = 1.
        im[1,:,2] = 1.
    return im
# Cell
def show_image_batch(b, show=show_titled_image, items=9, cols=3, figsize=None, **kwargs):
    "Display batch `b` in a grid of size `items` with `cols` width"
    if items<cols: cols=items
    # Ceiling division: enough rows to hold `items` cells
    rows = (items+cols-1) // cols
    if figsize is None: figsize = (cols*3, rows*3)
    _,axs = plt.subplots(rows, cols, figsize=figsize)
    # Each sample is the tuple of per-item elements of `b`; `show` receives it as a list
    for *o,ax in zip(*to_cpu(b), axs.flatten()): show(o, ax=ax, **kwargs)
# Cell
def requires_grad(m):
    "Check if the first parameter of `m` requires grad or not; `False` when `m` has no parameters"
    ps = list(m.parameters())
    return ps[0].requires_grad if len(ps)>0 else False
# Cell
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0; does nothing when `func` is falsy. Returns `m`."
    if func:
        if hasattr(m, 'weight'): func(m.weight)
        # `bias` may exist but be None, hence the extra `data` check
        if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
    return m
# Cell
def cond_init(m, func):
    "Apply `init_default` to `m` unless it's a batchnorm module"
    # Modules whose first parameter does not require grad (or that have none) are skipped too
    if (not isinstance(m, norm_types)) and requires_grad(m): init_default(m, func)
# Cell
def apply_leaf(m, f):
    "Apply `f` to `m` and, recursively, to all of its children."
    c = m.children()
    if isinstance(m, nn.Module): f(m)
    for l in c: apply_leaf(l,f)
# Cell
def apply_init(m, func=nn.init.kaiming_normal_):
    "Initialize all non-batchnorm layers of `m` with `func`."
    apply_leaf(m, partial(cond_init, func=func))
# Cell
def script_use_ctx(f):
    "Decorator: create jit script and pass everything in `ctx.saved_variables` to `f`, after `*args`"
    sf = torch.jit.script(f)
    # `ctx` itself is not forwarded to the scripted function, only its saved variables
    def _f(ctx, *args, **kwargs): return sf(*args, *ctx.saved_variables, **kwargs)
    return update_wrapper(_f,f)
# Cell
def script_save_ctx(static, *argidx):
    "Decorator: create jit script and save args with indices `argidx` using `ctx.save_for_backward`"
    def _dec(f):
        sf = torch.jit.script(f)
        def _f(ctx, *args, **kwargs):
            if argidx:
                save = [args[o] for o in argidx]
                ctx.save_for_backward(*save)
            # Without indices, `ctx` is passed on as the first argument; `args` is a tuple, so prepend a tuple
            if not argidx: args = (ctx,)+args
            return sf(*args, **kwargs)
        if static: _f = staticmethod(_f)
        return update_wrapper(_f,f)
    return _dec
# Cell
def script_fwd(*argidx):
    "Decorator: create static jit script and save args with indices `argidx` using `ctx.save_for_backward`"
    return script_save_ctx(True, *argidx)
# Cell
def script_bwd(f):
    "Decorator: create static jit script and pass everything in `ctx.saved_variables` to `f`, after `*args`"
    return staticmethod(script_use_ctx(f))
# Cell
def grad_module(cls):
    "Decorator: convert `cls` into an autograd function"
    # The result is an `nn.Module` class whose `forward` delegates to `cls.apply`
    class _c(nn.Module):
        def forward(self, *args, **kwargs): return cls.apply(*args, **kwargs)
    return _c
# Comes from 13b_metrics.ipynb, cell
def flatten_check(inp, targ):
    "Check that `inp` and `targ` have the same number of elements and flatten them."
    inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
    test_eq(len(inp), len(targ))
    return inp,targ