# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_torch_core.ipynb (unless otherwise specified).
__all__ = ['progress_bar', 'master_bar', 'subplots', 'show_image', 'show_titled_image', 'show_images', 'ArrayBase',
'ArrayImageBase', 'ArrayImage', 'ArrayImageBW', 'ArrayMask', 'tensor', 'set_seed', 'get_random_states',
'set_random_states', 'no_random', 'unsqueeze', 'unsqueeze_', 'apply', 'maybe_gather', 'to_detach', 'to_half',
'to_float', 'default_device', 'to_device', 'to_cpu', 'to_np', 'to_concat', 'TensorBase', 'TensorImageBase',
'TensorImage', 'TensorImageBW', 'TensorMask', 'TensorFlowField', 'TensorCategory', 'TensorMultiCategory',
'TitledTensorScalar', 'concat', 'Chunks', 'show_title', 'ShowTitle', 'TitledInt', 'TitledFloat', 'TitledStr',
'TitledTuple', 'get_empty_df', 'display_df', 'get_first', 'one_param', 'item_find', 'find_device', 'find_bs',
'np_func', 'Module', 'get_model', 'one_hot', 'one_hot_decode', 'params', 'trainable_params', 'norm_types',
'norm_bias_params', 'batch_to_samples', 'logit', 'num_distrib', 'rank_distrib', 'distrib_barrier',
'base_doc', 'doc', 'nested_reorder', 'make_cross_image', 'show_image_batch', 'requires_grad', 'init_default',
'cond_init', 'apply_leaf', 'apply_init', 'script_use_ctx', 'script_save_ctx', 'script_fwd', 'script_bwd',
'grad_module', 'ismin_torch', 'notmax_torch', 'flatten_check']
# Cell
from .imports import *
from .torch_imports import *
from packaging.version import parse
# Cell
#nbdev_comment _all_ = ['progress_bar','master_bar']
# Cell
if torch.cuda.is_available():
if torch.cuda.current_device()==0:
def_gpu = int(os.environ.get('DEFAULT_GPU') or 0)
        if torch.cuda.device_count()>def_gpu: torch.cuda.set_device(def_gpu)
torch.backends.cudnn.benchmark = True
# Cell
@delegates(plt.subplots, keep=True)
def subplots(nrows=1, ncols=1, figsize=None, imsize=3,suptitle=None, **kwargs):
if figsize is None:
h=nrows*imsize if suptitle is None or imsize>2 else nrows*imsize+0.6 #https://github.com/matplotlib/matplotlib/issues/5355
figsize=(ncols*imsize, h)
fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
if suptitle is not None: fig.suptitle(suptitle)
if nrows*ncols==1: ax = array([ax])
return fig,ax
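# Example (illustrative, not in the original notebook): a 2x3 grid sized from `imsize`.
#   fig, axs = subplots(2, 3, imsize=2)  # figsize defaults to (ncols*imsize, nrows*imsize) = (6, 4)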
# Cell
def _fig_bounds(x):
r = x//32
return min(5, max(1,r))
# Cell
@delegates(plt.Axes.imshow, keep=True, but=['shape', 'imlim'])
def show_image(im, ax=None, figsize=None, title=None, ctx=None, **kwargs):
"Show a PIL or PyTorch image on `ax`."
# Handle pytorch axis order
if hasattrs(im, ('data','cpu','permute')):
im = im.data.cpu()
if im.shape[0]<5: im=im.permute(1,2,0)
elif not isinstance(im,np.ndarray): im=array(im)
# Handle 1-channel images
if im.shape[-1]==1: im=im[...,0]
ax = ifnone(ax,ctx)
if figsize is None: figsize = (_fig_bounds(im.shape[0]), _fig_bounds(im.shape[1]))
if ax is None: _,ax = plt.subplots(figsize=figsize)
ax.imshow(im, **kwargs)
if title is not None: ax.set_title(title)
ax.axis('off')
return ax
# Cell
@delegates(show_image, keep=True)
def show_titled_image(o, **kwargs):
"Call `show_image` destructuring `o` to `(img,title)`"
show_image(o[0], title=str(o[1]), **kwargs)
# Cell
@delegates(subplots)
def show_images(ims, nrows=1, ncols=None, titles=None, **kwargs):
"Show all images `ims` as subplots with `rows` using `titles`."
if ncols is None: ncols = int(math.ceil(len(ims)/nrows))
if titles is None: titles = [None]*len(ims)
axs = subplots(nrows, ncols, **kwargs)[1].flat
for im,t,ax in zip(ims, titles, axs): show_image(im, ax=ax, title=t)
# Cell
class ArrayBase(ndarray):
"An `ndarray` that can modify casting behavior"
@classmethod
def _before_cast(cls, x): return x if isinstance(x,ndarray) else array(x)
# Cell
class ArrayImageBase(ArrayBase):
"Base class for arrays representing images"
_show_args = {'cmap':'viridis'}
def show(self, ctx=None, **kwargs):
return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})
# Cell
class ArrayImage(ArrayImageBase):
"An array representing an image"
pass
# Cell
class ArrayImageBW(ArrayImage):
"An array representing an image"
_show_args = {'cmap':'Greys'}
# Cell
class ArrayMask(ArrayImageBase):
"An array representing an image mask"
_show_args = {'alpha':0.5, 'cmap':'tab20', 'interpolation':'nearest'}
# Cell
@patch
def __array_eq__(self:Tensor,b):
return torch.equal(self,b) if self.dim() else self==b
# Cell
def _array2tensor(x):
if x.dtype==np.uint16: x = x.astype(np.float32)
    # Windows' default numpy int dtype is int32, while the torch tensor default int dtype is int64
    # https://github.com/numpy/numpy/issues/9464
    if sys.platform == "win32":
        if x.dtype==np.int_: x = x.astype(np.int64)  # np.int was removed in NumPy 1.24
return torch.from_numpy(x)
# Cell
@use_kwargs_dict(dtype=None, device=None, requires_grad=False, pin_memory=False)
def tensor(x, *rest, **kwargs):
"Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly."
if len(rest): x = (x,)+rest
# There was a Pytorch bug in dataloader using num_workers>0. Haven't confirmed if fixed
# if isinstance(x, (tuple,list)) and len(x)==0: return tensor(0)
res = (x if isinstance(x, Tensor)
else torch.tensor(x, **kwargs) if isinstance(x, (tuple,list))
else _array2tensor(x) if isinstance(x, ndarray)
else as_tensor(x.values, **kwargs) if isinstance(x, (pd.Series, pd.DataFrame))
# else as_tensor(array(x, **kwargs)) if hasattr(x, '__array__') or is_iter(x)
else _array2tensor(array(x), **kwargs))
if res.dtype is torch.float64: return res.float()
return res
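# Example (illustrative, not in the original notebook):
#   tensor(1, 2, 3)              # scalars passed directly -> tensor([1, 2, 3])
#   tensor([[1., 2.], [3., 4.]]) # nested lists go through torch.tensor
#   tensor(np.array([1., 2.]))   # float64 input is downcast to float32 by the final check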
# Cell
def set_seed(s, reproducible=False):
"Set random seed for `random`, `torch`, and `numpy` (where available)"
try: torch.manual_seed(s)
except NameError: pass
try: torch.cuda.manual_seed_all(s)
except NameError: pass
try: np.random.seed(s%(2**32-1))
except NameError: pass
random.seed(s)
if reproducible:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Cell
def get_random_states():
"Gets states for `random`, `torch`, and `numpy` random number generators"
return {'random_state':random.getstate(),
'numpy_state':np.random.get_state(),
'torch_state':torch.get_rng_state(),
'torch_cuda_state':torch.cuda.get_rng_state_all(),
'torch_deterministic':torch.backends.cudnn.deterministic,
'torch_benchmark':torch.backends.cudnn.benchmark}
# Cell
def set_random_states(random_state,numpy_state,torch_state,torch_cuda_state,torch_deterministic,torch_benchmark):
"Set states for `random`, `torch`, and `numpy` random number generators"
random.setstate(random_state)
np.random.set_state(numpy_state)
torch.set_rng_state(torch_state)
torch.cuda.set_rng_state_all(torch_cuda_state)
torch.backends.cudnn.deterministic=torch_deterministic
torch.backends.cudnn.benchmark=torch_benchmark
# Cell
@contextmanager
def no_random(seed=42,reproducible=True):
"Stores and retrieves state of random number generators. Sets random seed for `random`, `torch`, and `numpy`."
states = get_random_states()
set_seed(seed,reproducible=reproducible)
try:
yield #we are managing global variables
finally:
set_random_states(**states)
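# Example (illustrative, not in the original notebook): two draws under `no_random` match
# exactly, and the global RNG state is restored on exit.
#   with no_random(): a = torch.rand(3)
#   with no_random(): b = torch.rand(3)
#   test_eq(a, b)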
# Cell
def unsqueeze(x, dim=-1, n=1):
"Same as `torch.unsqueeze` but can add `n` dims"
for _ in range(n): x = x.unsqueeze(dim)
return x
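# Example (illustrative, not in the original notebook):
#   unsqueeze(torch.zeros(3), dim=0, n=2).shape  # -> torch.Size([1, 1, 3])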
# Cell
def unsqueeze_(x, dim=-1, n=1):
"Same as `torch.unsqueeze_` but can add `n` dims"
for _ in range(n): x.unsqueeze_(dim)
return x
# Cell
def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2(*args, **kwargs))
def _fa_rebuild_qtensor(cls, *args, **kwargs): return cls(torch._utils._rebuild_qtensor (*args, **kwargs))
# Cell
def apply(func, x, *args, **kwargs):
"Apply `func` recursively to `x`, passing on args"
if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])
if isinstance(x,dict): return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}
res = func(x, *args, **kwargs)
return res if x is None else retain_type(res, x)
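# Example (illustrative, not in the original notebook): `apply` walks nested lists/dicts
# and preserves container types.
#   apply(lambda t: t*2, {'a': tensor([1]), 'b': [tensor([2]), tensor([3])]})
#   # -> {'a': tensor([2]), 'b': [tensor([4]), tensor([6])]}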
# Cell
def maybe_gather(x, axis=0):
"Gather copies of `x` on `axis` (if training is distributed)"
if num_distrib()<=1: return x
ndim = x.ndim
res = [x.new_zeros(*x.shape if ndim > 0 else (1,)) for _ in range(num_distrib())]
torch.distributed.all_gather(res, x.contiguous() if ndim > 0 else x[None])
return torch.cat(res, dim=axis) if ndim > 0 else torch.cat(res, dim=axis).mean()
# Cell
def to_detach(b, cpu=True, gather=True):
"Recursively detach lists of tensors in `b `; put them on the CPU if `cpu=True`."
def _inner(x, cpu=True, gather=True):
if not isinstance(x,Tensor): return x
x = x.detach()
if gather: x = maybe_gather(x)
return x.cpu() if cpu else x
return apply(_inner, b, cpu=cpu, gather=gather)
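# Example (illustrative, not in the original notebook): detach a nested batch and move it to the CPU.
#   preds = {'out': torch.randn(2, requires_grad=True)}
#   to_detach(preds)['out'].requires_grad  # -> False (and the tensor now lives on the CPU)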
# Cell
def to_half(b):
"Recursively map lists of tensors in `b ` to FP16."
return apply(lambda x: x.half() if torch.is_floating_point(x) else x, b)
# Cell
def to_float(b):
"Recursively map lists of int tensors in `b ` to float."
return apply(lambda x: x.float() if torch.is_floating_point(x) else x, b)
# Cell
# None: True if available; True: error if not available; False: use CPU
defaults.use_cuda = None
# Cell
def default_device(use_cuda=-1):
"Return or set default device; `use_cuda`: None - CUDA if available; True - error if not available; False - CPU"
if use_cuda != -1: defaults.use_cuda=use_cuda
use = defaults.use_cuda or (torch.cuda.is_available() and defaults.use_cuda is None)
assert torch.cuda.is_available() or not use
return torch.device(torch.cuda.current_device()) if use else torch.device('cpu')
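# Example (illustrative, not in the original notebook):
#   default_device()       # -> device('cuda:N') when CUDA is available, else device('cpu')
#   default_device(False)  # force CPU for subsequent `to_device` calls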
# Cell
def to_device(b, device=None, non_blocking=False):
"Recursively put `b` on `device`."
if defaults.use_cuda==False: device='cpu'
elif device is None: device=default_device()
def _inner(o):
if isinstance(o,Tensor): return o.to(device, non_blocking=non_blocking)
# if hasattr(o, "to_device"): return o.to_device(device)
return o
return apply(_inner, b)
# Cell
def to_cpu(b):
"Recursively map lists of tensors in `b ` to the cpu."
return to_device(b,'cpu')
# Cell
def to_np(x):
"Convert a tensor to a numpy array."
return apply(lambda o: o.data.cpu().numpy(), x)
# Cell
def to_concat(xs, dim=0):
"Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
if not xs: return xs
if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
if isinstance(xs[0],dict): return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
#We may receive xs that are not concatenable (inputs of a text classifier for instance),
# in this case we return a big list
try: return retain_type(torch.cat(xs, dim=dim), xs[0])
except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
for i in range_of(o_)) for o_ in xs], L())
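# Example (illustrative, not in the original notebook): concatenation works element-wise
# through tuples of tensors.
#   to_concat([(tensor([1]), tensor([2])), (tensor([3]), tensor([4]))])
#   # -> (tensor([1, 3]), tensor([2, 4]))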
# Cell
@patch
def set_meta(self:Tensor, x, as_copy=False):
"Set all metadata in `__dict__`"
if not hasattr(x,'__dict__'): return
# XXX: change to `deepcopy` once PyTorch 1.7.1 is out, and check nb 23 segmentation fit works
self.__dict__ = copy(x.__dict__) if as_copy else x.__dict__
# Cell
if not hasattr(torch,'as_subclass'): torch.as_subclass = torch.Tensor.as_subclass
# Cell
@patch
def as_subclass(self:Tensor, typ):
"Cast to `typ` and include `__dict__` and meta"
return retain_meta(self, torch.as_subclass(self, typ))
# Cell
def _torch_handled(args, opt, func):
if func not in opt: return False
for oks in opt[func]:
if all(isinstance(arg,ok) for arg,ok in zip(args,oks) if ok): return True
# Cell
# from https://github.com/pytorch/pytorch/blob/13c975684a220ec096216ec6468ccd0dc90ff50a/torch/_tensor.py#L34
def _rebuild_from_type(func, type, args, dict):
ret = func(*args).as_subclass(type)
ret.__dict__ = dict
return ret
# Cell
class TensorBase(Tensor):
"A `Tensor` which support subclass pickling, and maintains metadata when casting or after methods"
debug,_opt = False,defaultdict(list)
def __new__(cls, x, **kwargs):
res = cast(tensor(x), cls)
for k,v in kwargs.items(): setattr(res, k, v)
return res
@classmethod
def _before_cast(cls, x): return tensor(x)
def __repr__(self): return re.sub('tensor', self.__class__.__name__, super().__repr__())
def __reduce_ex__(self,proto):
torch.utils.hooks.warn_if_has_hooks(self)
args = (self.storage(), self.storage_offset(), tuple(self.size()), self.stride())
if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point())
args = args + (self.requires_grad, OrderedDict())
f = torch._utils._rebuild_qtensor if self.is_quantized else torch._utils._rebuild_tensor_v2
return (_rebuild_from_type, (f, type(self), args, self.__dict__))
@classmethod
def register_func(cls, func, *oks): cls._opt[func].append(oks)
def __torch_function__(self, func, types, args=(), kwargs=None):
if self.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
convert=False
if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
if convert: res = convert(res)
if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)
return res
def new_tensor(self, size, dtype=None, device=None, requires_grad=False):
cls = type(self)
return self.as_subclass(Tensor).new_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)
def new_ones(self, data, dtype=None, device=None, requires_grad=False):
cls = type(self)
return self.as_subclass(Tensor).new_ones(data, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)
def new(self, x=None):
cls = type(self)
res = self.as_subclass(Tensor).new() if x is None else self.as_subclass(Tensor).new(x)
return res.as_subclass(cls)
def requires_grad_(self, requires_grad=True):
# Workaround https://github.com/pytorch/pytorch/issues/50219
self.requires_grad = requires_grad
return self
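# Example (illustrative, not in the original notebook): metadata set at construction
# survives tensor ops, carried through `__torch_function__` via `set_meta`.
#   t = TensorBase(torch.ones(3), img_size=224)
#   (t + 1).img_size  # -> 224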
# Cell
class TensorImageBase(TensorBase):
_show_args = ArrayImageBase._show_args
def show(self, ctx=None, **kwargs):
return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})
# Cell
class TensorImage(TensorImageBase): pass
# Cell
class TensorImageBW(TensorImage): _show_args = ArrayImageBW._show_args
# Cell
class TensorMask(TensorImageBase):
_show_args = ArrayMask._show_args
def show(self, ctx=None, **kwargs):
codes = getattr(self, 'codes', None)
if codes is not None: kwargs = merge({'vmin': 0, 'vmax': len(codes)}, kwargs)
return super().show(ctx=ctx, **kwargs)
# Cell
for o in Tensor.__getitem__, Tensor.__ne__,Tensor.__eq__,Tensor.add,Tensor.sub,Tensor.mul,Tensor.div,Tensor.__rsub__,Tensor.__radd__,Tensor.matmul,Tensor.bmm:
TensorBase.register_func(o, TensorMask, TensorImageBase)
TensorBase.register_func(o, TensorImageBase, TensorMask)
TensorMask.register_func(torch.einsum, str, TensorImageBase, TensorMask)
TensorMask.register_func(torch.einsum, str, TensorMask, TensorImageBase)
# Cell
class TensorFlowField(TensorBase): pass
TensorImage.register_func(F.grid_sample, TensorImageBase, TensorFlowField)
# Cell
class TensorCategory(TensorBase): pass
TensorBase.register_func(Tensor.__getitem__, TensorImageBase, TensorCategory)
# Cell
class TensorMultiCategory(TensorCategory): pass
# Cell
class TitledTensorScalar(TensorBase):
"A tensor containing a scalar that has a `show` method"
def show(self, **kwargs): show_title(self.item(), **kwargs)
# Cell
@patch
def tensored(self:L):
"`mapped(tensor)`"
return self.map(tensor)
@patch
def stack(self:L, dim=0):
"Same as `torch.stack`"
return torch.stack(list(self.tensored()), dim=dim)
@patch
def cat (self:L, dim=0):
"Same as `torch.cat`"
return torch.cat (list(self.tensored()), dim=dim)
# Cell
def concat(*ls):
"Concatenate tensors, arrays, lists, or tuples"
if not len(ls): return []
it = ls[0]
if isinstance(it,torch.Tensor): res = torch.cat(ls)
elif isinstance(it,ndarray): res = np.concatenate(ls)
else:
res = itertools.chain.from_iterable(map(L,ls))
if isinstance(it,(tuple,list)): res = type(it)(res)
else: res = L(res)
return retain_type(res, it)
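# Example (illustrative, not in the original notebook): the type of the first argument wins.
#   concat([1, 2], (3, 4))                 # -> [1, 2, 3, 4]
#   concat(torch.ones(2), torch.zeros(1))  # -> tensor([1., 1., 0.])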
# Cell
class Chunks:
"Slice and int indexing into a list of lists"
def __init__(self, chunks, lens=None):
self.chunks = chunks
self.lens = L(map(len,self.chunks) if lens is None else lens)
self.cumlens = np.cumsum(0+self.lens)
self.totlen = self.cumlens[-1]
def __getitem__(self,i):
if isinstance(i,slice): return retain_type(self.getslice(i), old=self.chunks[0])
di,idx = self.doc_idx(i)
return retain_type(self.chunks[di][idx], old=self.chunks[0])
def getslice(self, i):
st_d,st_i = self.doc_idx(ifnone(i.start,0))
en_d,en_i = self.doc_idx(ifnone(i.stop,self.totlen+1))
res = [self.chunks[st_d][st_i:(en_i if st_d==en_d else sys.maxsize)]]
for b in range(st_d+1,en_d): res.append(self.chunks[b])
if st_d!=en_d and en_d<len(self.chunks): res.append(self.chunks[en_d][:en_i])
return concat(*res)
def doc_idx(self, i):
if i<0: i=self.totlen+i # count from end
docidx = np.searchsorted(self.cumlens, i+1)-1
cl = self.cumlens[docidx]
return docidx,i-cl
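# Example (illustrative, not in the original notebook): index across sub-lists as if they
# were one flat list.
#   c = Chunks([[0, 1, 2], [3, 4], [5, 6, 7, 8]])
#   c[4]    # -> 4
#   c[2:6]  # -> [2, 3, 4, 5]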
# Cell
def show_title(o, ax=None, ctx=None, label=None, color='black', **kwargs):
"Set title of `ax` to `o`, or print `o` if `ax` is `None`"
ax = ifnone(ax,ctx)
if ax is None: print(o)
elif hasattr(ax, 'set_title'):
t = ax.title.get_text()
if len(t) > 0: o = t+'\n'+str(o)
ax.set_title(o, color=color)
elif isinstance(ax, pd.Series):
        while label in ax: label += '_'
        ax = pd.concat([ax, pd.Series({label: o})])  # pd.Series.append was removed in pandas 2.0
return ax
# Cell
class ShowTitle:
"Base class that adds a simple `show`"
_show_args = {'label': 'text'}
def show(self, ctx=None, **kwargs):
"Show self"
return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledInt(Int, ShowTitle):
_show_args = {'label': 'text'}
def show(self, ctx=None, **kwargs):
"Show self"
return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledFloat(Float, ShowTitle):
_show_args = {'label': 'text'}
def show(self, ctx=None, **kwargs):
"Show self"
return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledStr(Str, ShowTitle):
_show_args = {'label': 'text'}
def show(self, ctx=None, **kwargs):
"Show self"
return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
class TitledTuple(fastuple, ShowTitle):
_show_args = {'label': 'text'}
def show(self, ctx=None, **kwargs):
"Show self"
return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))
add_docs(TitledInt, "An `int` with `show`"); add_docs(TitledStr, "An `str` with `show`");
add_docs(TitledFloat, "A `float` with `show`"); add_docs(TitledTuple, "A `fastuple` with `show`")
# Cell
@patch
def truncate(self:TitledStr, n):
"Truncate self to `n`"
words = self.split(' ')[:n]
return TitledStr(' '.join(words))
# Cell
if not hasattr(pd.DataFrame,'_old_init'): pd.DataFrame._old_init = pd.DataFrame.__init__
# Cell
@patch
def __init__(self:pd.DataFrame, data=None, index=None, columns=None, dtype=None, copy=None):
if data is not None and isinstance(data, Tensor): data = to_np(data)
self._old_init(data, index=index, columns=columns, dtype=dtype, copy=copy)
# Cell
def get_empty_df(n):
"Return `n` empty rows of a dataframe"
df = pd.DataFrame(index = range(n))
return [df.iloc[i] for i in range(n)]
# Cell
def display_df(df):
"Display `df` in a notebook or defaults to print"
try: from IPython.display import display, HTML
except: return print(df)
display(HTML(df.to_html()))
# Cell
def get_first(c):
"Get the first element of c, even if c is a dataframe"
return getattr(c, 'iloc', c)[0]
# Cell
def one_param(m):
"First parameter in `m`"
return first(m.parameters())
# Cell
def item_find(x, idx=0):
"Recursively takes the `idx`-th element of `x`"
if is_listy(x): return item_find(x[idx])
if isinstance(x,dict):
key = list(x.keys())[idx] if isinstance(idx, int) else idx
return item_find(x[key])
return x
# Cell
def find_device(b):
"Recursively search the device of `b`."
return item_find(b).device
# Cell
def find_bs(b):
"Recursively search the batch size of `b`."
return item_find(b).shape[0]
# Cell
def np_func(f):
"Convert a function taking and returning numpy arrays to one taking and returning tensors"
def _inner(*args, **kwargs):
nargs = [to_np(arg) if isinstance(arg,Tensor) else arg for arg in args]
return tensor(f(*nargs, **kwargs))
functools.update_wrapper(_inner, f)
return _inner
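# Example (illustrative, not in the original notebook): wrap a numpy-only function so it
# accepts and returns tensors.
#   @np_func
#   def np_mean(a): return np.mean(a)
#   np_mean(tensor([1., 2., 3.]))  # -> tensor(2.)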
# Cell
class Module(nn.Module, metaclass=PrePostInitMeta):
"Same as `nn.Module`, but no need for subclasses to call `super().__init__`"
def __pre_init__(self, *args, **kwargs): super().__init__()
def __init__(self): pass
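# Example (illustrative, not in the original notebook): no super().__init__() call is
# needed in subclasses, since `__pre_init__` handles it.
#   class _Scale(Module):
#       def __init__(self, k): self.k = k
#       def forward(self, x): return x * self.k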
# Cell
from torch.nn.parallel import DistributedDataParallel
# Cell
def get_model(model):
"Return the model maybe wrapped inside `model`."
return model.module if isinstance(model, (DistributedDataParallel, nn.DataParallel)) else model
# Cell
def one_hot(x, c):
"One-hot encode `x` with `c` classes."
res = torch.zeros(c, dtype=torch.uint8)
if isinstance(x, Tensor) and x.numel()>0: res[x] = 1.
else: res[list(L(x, use_list=None))] = 1.
return res
# Cell
def one_hot_decode(x, vocab=None):
return L(vocab[i] if vocab else i for i,x_ in enumerate(x) if x_==1)
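# Example (illustrative, not in the original notebook): the two functions are inverses.
#   one_hot([0, 2], 5)                  # -> tensor([1, 0, 1, 0, 0], dtype=torch.uint8)
#   one_hot_decode(one_hot([0, 2], 5))  # -> [0, 2] (as an `L`)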
# Cell
def params(m):
"Return all parameters of `m`"
return [p for p in m.parameters()]
# Cell
def trainable_params(m):
"Return all trainable parameters of `m`"
return [p for p in m.parameters() if p.requires_grad]
# Cell
norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, nn.LayerNorm)
# Cell
def norm_bias_params(m, with_bias=True):
"Return all bias and BatchNorm parameters"
if isinstance(m, norm_types): return L(m.parameters())
res = L(m.children()).map(norm_bias_params, with_bias=with_bias).concat()
if with_bias and getattr(m, 'bias', None) is not None: res.append(m.bias)
return res
# Cell
def batch_to_samples(b, max_n=10):
"'Transposes' a batch to (at most `max_n`) samples"
if isinstance(b, Tensor): return retain_types(list(b[:max_n]), [b])
else:
res = L(b).map(partial(batch_to_samples,max_n=max_n))
return retain_types(res.zip(), [b])
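# Example (illustrative, not in the original notebook): an (inputs, targets) batch becomes
# a list of per-sample pairs.
#   b = (tensor([[1, 2], [3, 4]]), tensor([0, 1]))
#   batch_to_samples(b)  # -> [(tensor([1, 2]), tensor(0)), (tensor([3, 4]), tensor(1))]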
# Cell
@patch
def interp_1d(x:Tensor, xp, fp):
"Same as `np.interp`"
slopes = (fp[1:]-fp[:-1])/(xp[1:]-xp[:-1])
incx = fp[:-1] - (slopes*xp[:-1])
locs = (x[:,None]>=xp[None,:]).long().sum(1)-1
locs = locs.clamp(0,len(slopes)-1)
return slopes[locs]*x + incx[locs]
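# Example (illustrative, not in the original notebook): behaves like `np.interp` on 1-d tensors.
#   x = torch.linspace(0, 1, 5)
#   x.interp_1d(tensor([0., 1.]), tensor([0., 10.]))  # -> tensor([0., 2.5, 5., 7.5, 10.])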
# Cell
@patch
def pca(x:Tensor, k=2):
"Compute PCA of `x` with `k` dimensions."
x = x-torch.mean(x,0)
U,S,V = torch.svd(x.t())
return torch.mm(x,U[:,:k])
# Cell
def logit(x):
"Logit of `x`, clamped to avoid inf."
x = x.clamp(1e-7, 1-1e-7)
return -(1/x-1).log()
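# Example (illustrative, not in the original notebook): `logit` inverts sigmoid within the
# clamped range.
#   logit(torch.sigmoid(torch.tensor([-2., 0., 2.])))  # -> tensor([-2., 0., 2.]) (approx.)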
# Cell
def num_distrib():
"Return the number of processes in distributed training (if applicable)."
return int(os.environ.get('WORLD_SIZE', 0))
# Cell
def rank_distrib():
"Return the distributed rank of this process (if applicable)."
return int(os.environ.get('RANK', 0))
# Cell
def distrib_barrier():
"Place a synchronization barrier in distributed training"
if num_distrib() > 1 and torch.distributed.is_initialized(): torch.distributed.barrier()
# Cell
# Saving arrays requires pytables - optional dependency
try: import tables
except: pass
# Cell
def _comp_filter(lib='lz4',lvl=3): return tables.Filters(complib=f'blosc:{lib}', complevel=lvl)
# Cell
@patch
def save_array(p:Path, o, complib='lz4', lvl=3):
"Save numpy array to a compressed `pytables` file, using compression level `lvl`"
if isinstance(o,Tensor): o = to_np(o)
with tables.open_file(p, mode='w', filters=_comp_filter(lib=complib,lvl=lvl)) as f: f.create_carray('/', 'data', obj=o)
# Cell
@patch
def load_array(p:Path):
"Save numpy array to a `pytables` file"
with tables.open_file(p, 'r') as f: return f.root.data.read()
# Cell
def base_doc(elt):
"Print a base documentation of `elt`"
name = getattr(elt, '__qualname__', getattr(elt, '__name__', ''))
print(f'{name}{inspect.signature(elt)}\n{inspect.getdoc(elt)}\n')
print('To get a prettier result with hyperlinks to source code and documentation, install nbdev: pip install nbdev')
# Cell
def doc(elt):
"Try to use doc form nbdev and fall back to `base_doc`"
try:
from nbdev.showdoc import doc
doc(elt)
except: base_doc(elt)
# Cell
def nested_reorder(t, idxs):
"Reorder all tensors in `t` using `idxs`"
if isinstance(t, (Tensor,L)): return t[idxs]
elif is_listy(t): return type(t)(nested_reorder(t_, idxs) for t_ in t)
if t is None: return t
raise TypeError(f"Expected tensor, tuple, list or L but got {type(t)}")
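# Example (illustrative, not in the original notebook):
#   nested_reorder((tensor([10, 20, 30]),), tensor([2, 0, 1]))  # -> (tensor([30, 10, 20]),)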
# Cell
def make_cross_image(bw=True):
"Create a tensor containing a cross image, either `bw` (True) or color"
if bw:
im = torch.zeros(5,5)
im[2,:] = 1.
im[:,2] = 1.
else:
im = torch.zeros(3,5,5)
im[0,2,:] = 1.
im[1,:,2] = 1.
return im
# Cell
def show_image_batch(b, show=show_titled_image, items=9, cols=3, figsize=None, **kwargs):
"Display batch `b` in a grid of size `items` with `cols` width"
if items<cols: cols=items
rows = (items+cols-1) // cols
if figsize is None: figsize = (cols*3, rows*3)
fig,axs = plt.subplots(rows, cols, figsize=figsize)
for *o,ax in zip(*to_cpu(b), axs.flatten()): show(o, ax=ax, **kwargs)
# Cell
def requires_grad(m):
"Check if the first parameter of `m` requires grad or not"
ps = list(m.parameters())
return ps[0].requires_grad if len(ps)>0 else False
# Cell
def init_default(m, func=nn.init.kaiming_normal_):
"Initialize `m` weights with `func` and set `bias` to 0."
if func:
if hasattr(m, 'weight'): func(m.weight)
if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
return m
# Cell
def cond_init(m, func):
"Apply `init_default` to `m` unless it's a batchnorm module"
if (not isinstance(m, norm_types)) and requires_grad(m): init_default(m, func)
# Cell
def apply_leaf(m, f):
"Apply `f` to children of `m`."
c = m.children()
if isinstance(m, nn.Module): f(m)
for l in c: apply_leaf(l,f)
# Cell
def apply_init(m, func=nn.init.kaiming_normal_):
"Initialize all non-batchnorm layers of `m` with `func`."
apply_leaf(m, partial(cond_init, func=func))
# Cell
def script_use_ctx(f):
"Decorator: create jit script and pass everything in `ctx.saved_variables to `f`, after `*args`"
sf = torch.jit.script(f)
def _f(ctx, *args, **kwargs): return sf(*args, *ctx.saved_variables, **kwargs)
return update_wrapper(_f,f)
# Cell
def script_save_ctx(static, *argidx):
"Decorator: create jit script and save args with indices `argidx` using `ctx.save_for_backward`"
def _dec(f):
sf = torch.jit.script(f)
def _f(ctx, *args, **kwargs):
if argidx:
save = [args[o] for o in argidx]
ctx.save_for_backward(*save)
            if not argidx: args = (ctx,)+args  # `args` is a tuple, so prepend with tuple concat
return sf(*args, **kwargs)
if static: _f = staticmethod(_f)
return update_wrapper(_f,f)
return _dec
# Cell
def script_fwd(*argidx):
"Decorator: create static jit script and save args with indices `argidx` using `ctx.save_for_backward`"
return script_save_ctx(True, *argidx)
# Cell
def script_bwd(f):
"Decorator: create static jit script and pass everything in `ctx.saved_variables to `f`, after `*args`"
return staticmethod(script_use_ctx(f))
# Cell
def grad_module(cls):
"Decorator: convert `cls` into an autograd function"
class _c(nn.Module):
def forward(self, *args, **kwargs): return cls.apply(*args, **kwargs)
return _c
# Cell
def ismin_torch(min_version):
"Check if `torch.__version__` >= `min_version` using packaging.version"
return parse(torch.__version__) >= parse(min_version)
# Cell
def notmax_torch(max_version):
"Check if `torch.__version__` < `max_version` using packaging.version"
return parse(torch.__version__) < parse(max_version)
# Comes from 13b_metrics.ipynb, cell
def flatten_check(inp, targ):
"Check that `out` and `targ` have the same number of elements and flatten them."
inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
test_eq(len(inp), len(targ))
return inp,targ
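# Example (illustrative, not in the original notebook): shapes may differ as long as the
# element counts match.
#   inp, targ = flatten_check(torch.zeros(2, 3), torch.zeros(6))
#   inp.shape, targ.shape  # -> (torch.Size([6]), torch.Size([6]))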