/
torch_core.py
873 lines (745 loc) · 34.1 KB
/
torch_core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_torch_core.ipynb.
# %% ../nbs/00_torch_core.ipynb 2
from __future__ import annotations
from .imports import *
from .torch_imports import *
from packaging.version import parse
# %% auto 0
__all__ = ['norm_types', 'setup_cuda', 'subplots', 'show_image', 'show_titled_image', 'show_images', 'ArrayBase',
'ArrayImageBase', 'ArrayImage', 'ArrayImageBW', 'ArrayMask', 'tensor', 'set_seed', 'get_random_states',
'set_random_states', 'no_random', 'unsqueeze', 'unsqueeze_', 'apply', 'maybe_gather', 'to_detach', 'to_half',
'to_float', 'default_device', 'to_device', 'to_cpu', 'to_np', 'to_concat', 'TensorBase', 'TensorImageBase',
'TensorImage', 'TensorImageBW', 'TensorMask', 'TensorFlowField', 'TensorCategory', 'TensorMultiCategory',
'TitledTensorScalar', 'concat', 'Chunks', 'show_title', 'ShowTitle', 'TitledInt', 'TitledFloat', 'TitledStr',
'TitledTuple', 'get_empty_df', 'display_df', 'get_first', 'one_param', 'item_find', 'find_device', 'find_bs',
'np_func', 'Module', 'get_model', 'one_hot', 'one_hot_decode', 'params', 'trainable_params',
'norm_bias_params', 'batch_to_samples', 'logit', 'num_distrib', 'rank_distrib', 'distrib_barrier',
'base_doc', 'doc', 'nested_reorder', 'flatten_check', 'make_cross_image', 'show_image_batch',
'requires_grad', 'init_default', 'cond_init', 'apply_leaf', 'apply_init', 'script_use_ctx',
'script_save_ctx', 'script_fwd', 'script_bwd', 'grad_module', 'ismin_torch', 'notmax_torch', 'progress_bar',
'master_bar']
# %% ../nbs/00_torch_core.ipynb 5
_all_ = ['progress_bar','master_bar']
# %% ../nbs/00_torch_core.ipynb 6
defaults.benchmark = True
# %% ../nbs/00_torch_core.ipynb 7
def setup_cuda(benchmark=defaults.benchmark):
    "Sets the main cuda device and sets `cudnn.benchmark` to `benchmark`"
    # Nothing to configure on machines without CUDA.
    if not torch.cuda.is_available(): return
    # Only re-select the device while the process is still on device 0.
    if torch.cuda.current_device()==0:
        # An unset or empty `DEFAULT_GPU` falls back to device 0.
        def_gpu = int(os.environ.get('DEFAULT_GPU') or 0)
        # Valid indices are 0..device_count()-1, so the index must be strictly below the count
        # (the previous `>=` let `def_gpu == device_count()` through to `set_device`).
        if torch.cuda.device_count()>def_gpu: torch.cuda.set_device(def_gpu)
    torch.backends.cudnn.benchmark = benchmark
# %% ../nbs/00_torch_core.ipynb 10
@delegates(plt.subplots, keep=True)
def subplots(
    nrows:int=1, # Number of rows in returned axes grid
    ncols:int=1, # Number of columns in returned axes grid
    figsize:tuple=None, # Width, height in inches of the returned figure
    imsize:int=3, # Size (in inches) of images that will be displayed in the returned figure
    suptitle:str=None, # Title to be set to returned figure
    **kwargs
) -> (plt.Figure, plt.Axes): # Returns both fig and ax as a tuple
    "Returns a figure and set of subplots to display images of `imsize` inches"
    if figsize is None:
        height = nrows*imsize
        # With a suptitle and small images, add a little height for the title
        # https://github.com/matplotlib/matplotlib/issues/5355
        if suptitle is not None and not imsize>2: height = nrows*imsize+0.6
        figsize = (ncols*imsize, height)
    fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
    if suptitle is not None:
        fig.suptitle(suptitle)
    # A 1x1 grid gives back a single Axes; wrap it so the second item is always an array.
    if nrows*ncols==1:
        ax = array([ax])
    return fig,ax
# %% ../nbs/00_torch_core.ipynb 13
def _fig_bounds(x):
r = x//32
return min(5, max(1,r))
# %% ../nbs/00_torch_core.ipynb 14
@delegates(plt.Axes.imshow, keep=True, but=['shape', 'imlim'])
def show_image(im, ax=None, figsize=None, title=None, ctx=None, **kwargs):
    "Show a PIL or PyTorch image on `ax`."
    # Tensor-like inputs: move to CPU and, when the leading axis is short (<5),
    # treat it as the channel axis and move it last.
    if hasattrs(im, ('data','cpu','permute')):
        im = im.data.cpu()
        if im.shape[0]<5:
            im = im.permute(1,2,0)
    elif not isinstance(im,np.ndarray):
        im = array(im)
    # Drop a trailing single channel so it is plotted as a 2d image.
    if im.shape[-1]==1:
        im = im[...,0]

    ax = ifnone(ax,ctx)
    if figsize is None:
        height,width = im.shape[0],im.shape[1]
        figsize = (_fig_bounds(height), _fig_bounds(width))
    if ax is None:
        _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    if title is not None:
        ax.set_title(title)
    ax.axis('off')
    return ax
# %% ../nbs/00_torch_core.ipynb 21
@delegates(show_image, keep=True)
def show_titled_image(o, **kwargs):
    "Call `show_image` destructuring `o` to `(img,title)`"
    img,title = o[0],str(o[1])
    show_image(img, title=title, **kwargs)
# %% ../nbs/00_torch_core.ipynb 24
@delegates(subplots)
def show_images(ims, nrows=1, ncols=None, titles=None, **kwargs):
    "Show all images `ims` as subplots with `nrows` rows using `titles`."
    n_ims = len(ims)
    # Without `ncols`, use just enough columns to fit every image in `nrows` rows.
    if ncols is None:
        ncols = int(math.ceil(n_ims/nrows))
    if titles is None:
        titles = [None]*n_ims
    _,grid = subplots(nrows, ncols, **kwargs)
    for im,t,ax in zip(ims, titles, grid.flat):
        show_image(im, ax=ax, title=t)
# %% ../nbs/00_torch_core.ipynb 27
class ArrayBase(ndarray):
    "An `ndarray` that can modify casting behavior"
    @classmethod
    def _before_cast(cls, x):
        "Return `x` unchanged if it is already an `ndarray`, otherwise convert it with `array`"
        if isinstance(x,ndarray): return x
        return array(x)
# %% ../nbs/00_torch_core.ipynb 28
class ArrayImageBase(ArrayBase):
    "Base class for arrays representing images"
    # Default keyword arguments for `show_image`; explicit `show` kwargs take precedence.
    _show_args = {'cmap':'viridis'}
    def show(self, ctx=None, **kwargs):
        "Display the array with `show_image` on `ctx`"
        show_kwargs = {**self._show_args, **kwargs}
        return show_image(self, ctx=ctx, **show_kwargs)
# %% ../nbs/00_torch_core.ipynb 29
class ArrayImage(ArrayImageBase):
    "An array representing an image"
# %% ../nbs/00_torch_core.ipynb 30
class ArrayImageBW(ArrayImage):
    "An array representing an image, shown with the `Greys` colormap"
    _show_args = {'cmap':'Greys'}
# %% ../nbs/00_torch_core.ipynb 31
class ArrayMask(ArrayImageBase):
    "An array representing an image mask"
    # Half-transparent, categorical colormap, no interpolation between pixels.
    _show_args = {'alpha':0.5, 'cmap':'tab20', 'interpolation':'nearest'}
# %% ../nbs/00_torch_core.ipynb 37
@patch
def __array_eq__(self:Tensor,b):
    "Equality hook: whole-tensor `torch.equal` when `self.dim()` is non-zero, plain `==` for 0-dim tensors"
    if self.dim():
        return torch.equal(self,b)
    return self==b
# %% ../nbs/00_torch_core.ipynb 38
def _array2tensor(x, requires_grad=False, pin_memory=False, **kwargs):
if x.dtype==np.uint16: x = x.astype(np.float32)
# windows default numpy int dtype is int32, while torch tensor default int dtype is int64
# https://github.com/numpy/numpy/issues/9464
if sys.platform == "win32" and x.dtype==int: x = x.astype(np.int64)
t = torch.as_tensor(x, **kwargs)
t.requires_grad_(requires_grad)
if pin_memory: t.pin_memory()
return t
# %% ../nbs/00_torch_core.ipynb 39
@use_kwargs_dict(dtype=None, device=None, requires_grad=False, pin_memory=False)
def tensor(x, *rest, **kwargs):
    "Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly."
    # `tensor(1,2,3)` is treated as `tensor((1,2,3))`.
    if len(rest): x = (x,)+rest
    # There was a Pytorch bug in dataloader using num_workers>0. Haven't confirmed if fixed
    # if isinstance(x, (tuple,list)) and len(x)==0: return tensor(0)
    if isinstance(x, Tensor):
        res = x
    elif isinstance(x, (tuple,list,numbers.Number)):
        res = torch.tensor(x, **kwargs)
    elif isinstance(x, ndarray):
        res = _array2tensor(x, **kwargs)
    elif isinstance(x, (pd.Series, pd.DataFrame)):
        res = as_tensor(x.values, **kwargs)
    else:
        # Anything else goes through `array(x)` first.
        res = _array2tensor(array(x), **kwargs)
    # float64 results are converted to float32.
    return res.float() if res.dtype is torch.float64 else res
# %% ../nbs/00_torch_core.ipynb 42
def set_seed(s, reproducible=False):
    "Set random seed for `random`, `torch`, and `numpy` (where available)"
    # Each seeder is attempted on its own; a `NameError` (library name not defined) is skipped.
    seeders = (
        lambda: torch.manual_seed(s),
        lambda: torch.cuda.manual_seed_all(s),
        lambda: np.random.seed(s%(2**32-1)),
    )
    for seed_fn in seeders:
        try: seed_fn()
        except NameError: pass
    random.seed(s)
    if not reproducible: return
    # Request deterministic cuDNN kernels and turn off the benchmark auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# %% ../nbs/00_torch_core.ipynb 47
def get_random_states():
    "Gets states for `random`, `torch`, and `numpy` random number generators"
    cudnn = torch.backends.cudnn
    # Keys match the parameter names of `set_random_states`.
    return dict(
        random_state=random.getstate(),
        numpy_state=np.random.get_state(),
        torch_state=torch.get_rng_state(),
        torch_cuda_state=torch.cuda.get_rng_state_all(),
        torch_deterministic=cudnn.deterministic,
        torch_benchmark=cudnn.benchmark,
    )
# %% ../nbs/00_torch_core.ipynb 48
def set_random_states(random_state,numpy_state,torch_state,torch_cuda_state,torch_deterministic,torch_benchmark):
    "Set states for `random`, `torch`, and `numpy` random number generators"
    # Generator states first (python, numpy, torch CPU, torch CUDA) ...
    random.setstate(random_state)
    np.random.set_state(numpy_state)
    torch.set_rng_state(torch_state)
    torch.cuda.set_rng_state_all(torch_cuda_state)
    # ... then the two cuDNN flags.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = torch_deterministic
    cudnn.benchmark = torch_benchmark
# %% ../nbs/00_torch_core.ipynb 53
@contextmanager
def no_random(seed=42,reproducible=True):
    "Stores and retrieves state of random number generators. Sets random seed for `random`, `torch`, and `numpy`."
    saved = get_random_states()
    set_seed(seed,reproducible=reproducible)
    try:
        # Nothing is yielded: the context only manages global generator state.
        yield
    finally:
        # Put the saved state back even if the body raised.
        set_random_states(**saved)
# %% ../nbs/00_torch_core.ipynb 59
def unsqueeze(x, dim=-1, n=1):
    "Same as `torch.unsqueeze` but can add `n` dims"
    out = x
    for _step in range(n):
        out = out.unsqueeze(dim)
    return out
# %% ../nbs/00_torch_core.ipynb 61
def unsqueeze_(x, dim=-1, n=1):
    "Same as `torch.unsqueeze_` but can add `n` dims"
    # In-place: `x` itself is modified and returned.
    for _step in range(n):
        x.unsqueeze_(dim)
    return x
# %% ../nbs/00_torch_core.ipynb 63
def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2(*args, **kwargs))
def _fa_rebuild_qtensor(cls, *args, **kwargs):
    "Rebuild a quantized tensor with `torch._utils._rebuild_qtensor` and wrap it by calling `cls`"
    rebuilt = torch._utils._rebuild_qtensor(*args, **kwargs)
    return cls(rebuilt)
# %% ../nbs/00_torch_core.ipynb 64
def apply(func, x, *args, **kwargs):
    "Apply `func` recursively to `x`, passing on args"
    def _recurse(o): return apply(func, o, *args, **kwargs)
    # Containers are rebuilt: listy inputs with their own type, dicts as plain dicts.
    if is_listy(x):
        return type(x)([_recurse(o) for o in x])
    if isinstance(x,dict):
        return {k: _recurse(v) for k,v in x.items()}
    res = func(x, *args, **kwargs)
    # `None` inputs skip `retain_type`.
    if x is None: return res
    return retain_type(res, x)
# %% ../nbs/00_torch_core.ipynb 65
def maybe_gather(x, axis=0):
    "Gather copies of `x` on `axis` (if training is distributed)"
    world_size = num_distrib()
    if world_size<=1: return x
    has_dims = x.ndim > 0
    # 0-dim tensors are gathered as 1-element tensors and averaged at the end.
    buf_shape = x.shape if has_dims else (1,)
    buffers = [x.new_zeros(*buf_shape) for _ in range(world_size)]
    payload = x.contiguous() if has_dims else x[None]
    torch.distributed.all_gather(buffers, payload)
    gathered = torch.cat(buffers, dim=axis)
    return gathered if has_dims else gathered.mean()
# %% ../nbs/00_torch_core.ipynb 66
def to_detach(b, cpu=True, gather=True):
    "Recursively detach lists of tensors in `b `; put them on the CPU if `cpu=True`."
    def _detach_one(x, cpu=True, gather=True):
        # Non-tensors pass through untouched.
        if not isinstance(x,Tensor): return x
        detached = x.detach()
        if gather:
            detached = maybe_gather(detached)
        if cpu: return detached.cpu()
        return detached
    return apply(_detach_one, b, cpu=cpu, gather=gather)
# %% ../nbs/00_torch_core.ipynb 68
def to_half(b):
    "Recursively map floating point tensors in `b ` to FP16."
    def _half(x):
        if torch.is_floating_point(x): return x.half()
        return x
    return apply(_half, b)
# %% ../nbs/00_torch_core.ipynb 69
def to_float(b):
    "Recursively map floating point tensors in `b ` to float."
    def _float(x):
        if torch.is_floating_point(x): return x.float()
        return x
    return apply(_float, b)
# %% ../nbs/00_torch_core.ipynb 70
# None: True if available; True: error if not available; False: use CPU
defaults.use_cuda = None
# %% ../nbs/00_torch_core.ipynb 71
def _has_mps():
    "Whether torch reports the `mps` backend as available"
    # Look up `torch.backends.mps.is_available`, falling back to `noop` when the attribute is missing.
    is_available = nested_attr(torch, 'backends.mps.is_available', noop)
    if is_available(): return True
    # Otherwise fall back to the `torch.has_mps` attribute, if there is one.
    return getattr(torch, 'has_mps', False)
def default_device(use=-1):
    "Return or set default device; `use_cuda`: -1 - CUDA/mps if available; True - error if not available; False - CPU"
    # -1 reads the stored setting; any other value replaces it.
    if use == -1:
        use = defaults.use_cuda
    else:
        defaults.use_cuda = use
    # `None` means auto-detect: turn acceleration on when CUDA or mps is present.
    if use is None and (torch.cuda.is_available() or _has_mps()):
        use = True
    if not use:
        return torch.device('cpu')
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    if _has_mps():
        return torch.device('mps')
    return torch.device('cpu')
# %% ../nbs/00_torch_core.ipynb 73
def to_device(b, device=None, non_blocking=False):
    "Recursively put `b` on `device`."
    # `defaults.use_cuda==False` forces the CPU regardless of the `device` argument.
    if defaults.use_cuda==False:
        device = 'cpu'
    elif device is None:
        device = default_device()
    def _move(o):
        # if hasattr(o, "to_device"): return o.to_device(device)
        return o.to(device, non_blocking=non_blocking) if isinstance(o,Tensor) else o
    return apply(_move, b)
# %% ../nbs/00_torch_core.ipynb 76
def to_cpu(b):
    "Recursively map tensors in `b ` to the cpu."
    cpu_name = 'cpu'
    return to_device(b, cpu_name)
# %% ../nbs/00_torch_core.ipynb 78
def to_np(x):
    "Convert a tensor to a numpy array."
    def _one(o):
        # `.data` -> CPU -> numpy
        return o.data.cpu().numpy()
    return apply(_one, x)
# %% ../nbs/00_torch_core.ipynb 80
def to_concat(xs, dim=0):
    "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
    if not xs: return xs
    head = xs[0]
    # Listy / dict items are concatenated position by position (resp. key by key).
    if is_listy(head):
        return type(head)([to_concat([x[i] for x in xs], dim=dim) for i in range_of(head)])
    if isinstance(head,dict):
        return {k: to_concat([x[k] for x in xs], dim=dim) for k in head.keys()}
    # We may receive xs that are not concatenable (inputs of a text classifier for instance),
    # in this case we return a big list
    try: return retain_type(torch.cat(xs, dim=dim), head)
    # `Exception` rather than a bare `except`, so KeyboardInterrupt/SystemExit still propagate.
    except Exception:
        return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), head)
                      for i in range_of(o_)) for o_ in xs], L())
# %% ../nbs/00_torch_core.ipynb 84
@patch
def set_meta(self:Tensor, x, as_copy=False):
    "Set all metadata in `__dict__`"
    # Objects without a `__dict__` carry no metadata to take.
    if not hasattr(x,'__dict__'): return
    # XXX: change to `deepcopy` once PyTorch 1.7.1 is out, and check nb 23 segmentation fit works
    if as_copy:
        self.__dict__ = copy(x.__dict__)
    else:
        # Without `as_copy` the very same dict object is shared with `x`.
        self.__dict__ = x.__dict__
# %% ../nbs/00_torch_core.ipynb 85
if not hasattr(torch,'as_subclass'): torch.as_subclass = torch.Tensor.as_subclass
# %% ../nbs/00_torch_core.ipynb 86
@patch
def as_subclass(self:Tensor, typ):
    "Cast to `typ` and include `__dict__` and meta"
    casted = torch.as_subclass(self, typ)
    return retain_meta(self, casted)
# %% ../nbs/00_torch_core.ipynb 89
def _torch_handled(args, opt, func):
if func not in opt: return False
for oks in opt[func]:
if all(isinstance(arg,ok) for arg,ok in zip(args,oks) if ok): return True
# %% ../nbs/00_torch_core.ipynb 90
# from https://github.com/pytorch/pytorch/blob/13c975684a220ec096216ec6468ccd0dc90ff50a/torch/_tensor.py#L34
def _rebuild_from_type(func, type, args, dict):
ret = func(*args).as_subclass(type)
ret.__dict__ = dict
return ret
# %% ../nbs/00_torch_core.ipynb 91
def _find_args(x):
    "Items of `x` (or of `x[0]` when that is a non-empty listy) that have a `__dict__`"
    head = x[0]
    candidates = head if is_listy(head) and head else x
    return [a for a in candidates if hasattr(a,'__dict__')]
# %% ../nbs/00_torch_core.ipynb 92
class TensorBase(Tensor):
    "A `Tensor` which supports subclass pickling, and maintains metadata when casting or after methods"
    # `debug`: print every dispatched torch function; `_opt`: func -> list of registered type signatures
    debug,_opt = False,defaultdict(list)
    def __new__(cls, x, **kwargs):
        "Cast `tensor(x)` to `cls` and set every item of `kwargs` as an attribute on the result"
        res = cast(tensor(x), cls)
        for k,v in kwargs.items(): setattr(res, k, v)
        return res

    @classmethod
    def _before_cast(cls, x): return tensor(x)
    def __repr__(self): return re.sub('tensor', self.__class__.__name__, super().__repr__())

    def __reduce_ex__(self,proto):
        "Pickle support: rebuild through `_rebuild_from_type` so the subclass and `__dict__` are restored"
        torch.utils.hooks.warn_if_has_hooks(self)
        args = (self.storage(), self.storage_offset(), tuple(self.size()), self.stride())
        if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point())
        args = args + (self.requires_grad, OrderedDict())
        f = torch._utils._rebuild_qtensor if self.is_quantized else torch._utils._rebuild_tensor_v2
        return (_rebuild_from_type, (f, type(self), args, self.__dict__))

    @classmethod
    def register_func(cls, func, *oks): cls._opt[func].append(oks)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        "Dispatch `func`, then copy metadata from the first argument that has a `__dict__` onto the result"
        if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
        # For registered type combinations, dispatch as if only plain tensors were involved.
        if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
        # Normalise once so the metadata lookup below cannot hit `None.values()`.
        kwargs = ifnone(kwargs, {})
        res = super().__torch_function__(func, types, args, kwargs)
        if args: dict_objs = _find_args(args)
        elif kwargs: dict_objs = _find_args(list(kwargs.values()))
        else: dict_objs = []  # no arguments at all: nothing to take metadata from
        if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)
        elif dict_objs and is_listy(res): [r.set_meta(dict_objs[0],as_copy=True) for r in res if issubclass(type(r),TensorBase)]
        return res

    # The `new_*` helpers go through a plain `Tensor` and cast back to the subclass.
    # NOTE(review): relative to torch, the first parameter names of `new_tensor` (`size`) and
    # `new_ones` (`data`) look swapped; kept as-is because renaming would break keyword callers.
    def new_tensor(self, size, dtype=None, device=None, requires_grad=False):
        cls = type(self)
        return self.as_subclass(Tensor).new_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)

    def new_ones(self, data, dtype=None, device=None, requires_grad=False):
        cls = type(self)
        return self.as_subclass(Tensor).new_ones(data, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)

    def new(self, x=None):
        cls = type(self)
        res = self.as_subclass(Tensor).new() if x is None else self.as_subclass(Tensor).new(x)
        return res.as_subclass(cls)

    def requires_grad_(self, requires_grad=True):
        # Workaround https://github.com/pytorch/pytorch/issues/50219
        self.requires_grad = requires_grad
        return self
# %% ../nbs/00_torch_core.ipynb 104
class TensorImageBase(TensorBase):
    "Base class for tensors representing images"
    # Same default display arguments as `ArrayImageBase` (the very same dict object).
    _show_args = ArrayImageBase._show_args
    def show(self, ctx=None, **kwargs):
        "Display the tensor with `show_image` on `ctx`; explicit `kwargs` override `_show_args`"
        show_kwargs = {**self._show_args, **kwargs}
        return show_image(self, ctx=ctx, **show_kwargs)
# %% ../nbs/00_torch_core.ipynb 105
class TensorImage(TensorImageBase):
    "A tensor representing an image"
# %% ../nbs/00_torch_core.ipynb 106
class TensorImageBW(TensorImage):
    "A tensor representing an image, shown with the display arguments of `ArrayImageBW`"
    _show_args = ArrayImageBW._show_args
# %% ../nbs/00_torch_core.ipynb 107
class TensorMask(TensorImageBase):
    "A tensor representing an image mask"
    _show_args = ArrayMask._show_args
    def show(self, ctx=None, **kwargs):
        "Show the mask; when a `codes` attribute is set, default the color range to `0..len(codes)`"
        codes = getattr(self, 'codes', None)
        if codes is not None:
            color_range = {'vmin': 0, 'vmax': len(codes)}
            # Caller-supplied kwargs win over the derived range.
            kwargs = merge(color_range, kwargs)
        return super().show(ctx=ctx, **kwargs)
# %% ../nbs/00_torch_core.ipynb 108
for o in Tensor.__getitem__, Tensor.__ne__,Tensor.__eq__,Tensor.add,Tensor.sub,Tensor.mul,Tensor.div,Tensor.__rsub__,Tensor.__radd__,Tensor.matmul,Tensor.bmm:
TensorBase.register_func(o, TensorMask, TensorImageBase)
TensorBase.register_func(o, TensorImageBase, TensorMask)
TensorMask.register_func(torch.einsum, str, TensorImageBase, TensorMask)
TensorMask.register_func(torch.einsum, str, TensorMask, TensorImageBase)
# %% ../nbs/00_torch_core.ipynb 115
class TensorFlowField(TensorBase):
    "A `TensorBase` subclass marking flow-field tensors (registered with `F.grid_sample` right after this class)"
TensorImage.register_func(F.grid_sample, TensorImageBase, TensorFlowField)
# %% ../nbs/00_torch_core.ipynb 117
class TensorCategory(TensorBase):
    "A `TensorBase` subclass marking category tensors"
TensorBase.register_func(Tensor.__getitem__, TensorImageBase, TensorCategory)
# %% ../nbs/00_torch_core.ipynb 119
class TensorMultiCategory(TensorCategory):
    "A `TensorCategory` subclass marking multi-category tensors"
# %% ../nbs/00_torch_core.ipynb 120
class TitledTensorScalar(TensorBase):
    "A tensor containing a scalar that has a `show` method"
    def show(self, **kwargs):
        "Pass the python value from `self.item()` to `show_title`"
        value = self.item()
        show_title(value, **kwargs)
# %% ../nbs/00_torch_core.ipynb 122
@patch
def tensored(self:L):
    "`mapped(tensor)`"
    as_tensors = self.map(tensor)
    return as_tensors
@patch
def stack(self:L, dim=0):
    "Same as `torch.stack`"
    items = list(self.tensored())
    return torch.stack(items, dim=dim)
@patch
def cat(self:L, dim=0):
    "Same as `torch.cat`"
    items = list(self.tensored())
    return torch.cat(items, dim=dim)
# %% ../nbs/00_torch_core.ipynb 131
def concat(*ls):
    "Concatenate tensors, arrays, lists, or tuples"
    if not len(ls): return []
    # The first argument decides how everything is joined and what type comes back.
    head = ls[0]
    if isinstance(head,torch.Tensor):
        joined = torch.cat(ls)
    elif isinstance(head,ndarray):
        joined = np.concatenate(ls)
    else:
        chained = itertools.chain.from_iterable(map(L,ls))
        # Tuples/lists keep their own type; anything else becomes an `L`.
        joined = type(head)(chained) if isinstance(head,(tuple,list)) else L(chained)
    return retain_type(joined, head)
# %% ../nbs/00_torch_core.ipynb 133
class Chunks:
    "Slice and int indexing into a list of lists"
    def __init__(self, chunks, lens=None):
        self.chunks = chunks
        # Per-chunk lengths: measured unless supplied by the caller.
        self.lens = L(map(len,self.chunks) if lens is None else lens)
        self.cumlens = np.cumsum(0+self.lens)
        self.totlen = self.cumlens[-1]

    def __getitem__(self,i):
        "Item or slice lookup; the result is passed through `retain_type` against the first chunk"
        if isinstance(i,slice):
            return retain_type(self.getslice(i), old=self.chunks[0])
        chunk_no,offset = self.doc_idx(i)
        return retain_type(self.chunks[chunk_no][offset], old=self.chunks[0])

    def getslice(self, i):
        "Concatenate the pieces of every chunk covered by slice `i` (only `start`/`stop` are read)"
        st_d,st_i = self.doc_idx(ifnone(i.start,0))
        en_d,en_i = self.doc_idx(ifnone(i.stop,self.totlen+1))
        same_chunk = st_d==en_d
        # First piece: up to `en_i` when start and stop share a chunk, else to the chunk's end.
        first_stop = en_i if same_chunk else sys.maxsize
        pieces = [self.chunks[st_d][st_i:first_stop]]
        # Whole chunks strictly between the first and the last.
        pieces += [self.chunks[b] for b in range(st_d+1,en_d)]
        # Leading part of the last chunk, when it is a different, existing chunk.
        if not same_chunk and en_d<len(self.chunks):
            pieces.append(self.chunks[en_d][:en_i])
        return concat(*pieces)

    def doc_idx(self, i):
        "Translate flat index `i` into `(chunk index, index within that chunk)`"
        if i<0: i = self.totlen+i # count from end
        chunk_no = np.searchsorted(self.cumlens, i+1)-1
        chunk_start = self.cumlens[chunk_no]
        return chunk_no,i-chunk_start
# %% ../nbs/00_torch_core.ipynb 138
def show_title(o, ax=None, ctx=None, label=None, color='black', **kwargs):
    "Set title of `ax` to `o`, or print `o` if `ax` is `None`"
    ax = ifnone(ax,ctx)
    if ax is None:
        print(o)
        return ax
    if hasattr(ax, 'set_title'):
        # Append below any title already on the axes.
        existing = ax.title.get_text()
        if len(existing) > 0: o = existing+'\n'+str(o)
        ax.set_title(o, color=color)
        return ax
    if isinstance(ax, pd.Series):
        # Suffix the label with underscores until it is not already in the series.
        while label in ax: label += '_'
        ax = pd.concat([ax,pd.Series({label: o})])
    return ax
# %% ../nbs/00_torch_core.ipynb 140
class ShowTitle:
    "Base class that adds a simple `show`"
    # Default keyword arguments for `show_title`; `show` kwargs override them.
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        opts = merge(self._show_args, kwargs)
        return show_title(str(self), ctx=ctx, **opts)
class TitledInt(Int, ShowTitle):
    # NOTE(review): `show` repeats `ShowTitle.show` verbatim; presumably so it takes precedence over
    # anything inherited through `Int`, which comes first in the MRO. Confirm before deduplicating.
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        opts = merge(self._show_args, kwargs)
        return show_title(str(self), ctx=ctx, **opts)
class TitledFloat(Float, ShowTitle):
    # NOTE(review): `show` repeats `ShowTitle.show` verbatim; presumably so it takes precedence over
    # anything inherited through `Float`, which comes first in the MRO. Confirm before deduplicating.
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        opts = merge(self._show_args, kwargs)
        return show_title(str(self), ctx=ctx, **opts)
class TitledStr(Str, ShowTitle):
    # NOTE(review): `show` repeats `ShowTitle.show` verbatim; presumably so it takes precedence over
    # anything inherited through `Str`, which comes first in the MRO. Confirm before deduplicating.
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        opts = merge(self._show_args, kwargs)
        return show_title(str(self), ctx=ctx, **opts)
class TitledTuple(fastuple, ShowTitle):
    # NOTE(review): `show` repeats `ShowTitle.show` verbatim; presumably so it takes precedence over
    # anything inherited through `fastuple`, which comes first in the MRO. Confirm before deduplicating.
    _show_args = {'label': 'text'}
    def show(self, ctx=None, **kwargs):
        "Show self"
        opts = merge(self._show_args, kwargs)
        return show_title(str(self), ctx=ctx, **opts)
add_docs(TitledInt, "An `int` with `show`"); add_docs(TitledStr, "An `str` with `show`");
add_docs(TitledFloat, "A `float` with `show`"); add_docs(TitledTuple, "A `fastuple` with `show`")
# %% ../nbs/00_torch_core.ipynb 147
@patch
def truncate(self:TitledStr, n):
    "Truncate self to `n`"
    # Counted in space-separated words, not characters.
    kept = self.split(' ')[:n]
    shortened = ' '.join(kept)
    return TitledStr(shortened)
# %% ../nbs/00_torch_core.ipynb 149
if not hasattr(pd.DataFrame,'_old_init'): pd.DataFrame._old_init = pd.DataFrame.__init__
# %% ../nbs/00_torch_core.ipynb 150
@patch
def __init__(self:pd.DataFrame, data=None, index=None, columns=None, dtype=None, copy=None):
    "`DataFrame.__init__` that also accepts a `Tensor` as `data`, by converting it with `to_np` first"
    # `isinstance` is already False for `None`, so no separate `None` check is needed.
    if isinstance(data, Tensor):
        data = to_np(data)
    # `_old_init` is the original constructor, saved on the class before this patch is applied.
    self._old_init(data, index=index, columns=columns, dtype=dtype, copy=copy)
# %% ../nbs/00_torch_core.ipynb 151
def get_empty_df(n):
    "Return `n` empty rows of a dataframe"
    # A frame with an index of `n` entries and no columns; each row is taken with `iloc`.
    frame = pd.DataFrame(index = range(n))
    rows = []
    for i in range(n):
        rows.append(frame.iloc[i])
    return rows
# %% ../nbs/00_torch_core.ipynb 152
def display_df(df):
    "Display `df` in a notebook or defaults to print"
    # Only a missing IPython should trigger the plain-text fallback; a bare `except`
    # would also hide KeyboardInterrupt and unrelated errors.
    try: from IPython.display import display, HTML
    except ImportError: return print(df)
    display(HTML(df.to_html()))
# %% ../nbs/00_torch_core.ipynb 153
def get_first(c):
    "Get the first element of c, even if c is a dataframe"
    # Use positional `.iloc` when `c` has one, otherwise index `c` directly.
    indexer = getattr(c, 'iloc', c)
    return indexer[0]
# %% ../nbs/00_torch_core.ipynb 154
def one_param(m):
    "First parameter in `m`"
    all_params = m.parameters()
    return first(all_params)
# %% ../nbs/00_torch_core.ipynb 155
def item_find(x, idx=0):
    "Recursively takes the `idx`-th element of `x`"
    # `idx` selects at the outermost level only; nested containers are then followed
    # through element 0 (the recursive form called itself with the default `idx`).
    pos = idx
    while True:
        if is_listy(x):
            x = x[pos]
        elif isinstance(x,dict):
            # Integer positions pick the n-th key; anything else is used as the key itself.
            key = list(x.keys())[pos] if isinstance(pos, int) else pos
            x = x[key]
        else:
            return x
        pos = 0
# %% ../nbs/00_torch_core.ipynb 156
def find_device(b):
    "Recursively search the device of `b`."
    leaf = item_find(b)
    return leaf.device
# %% ../nbs/00_torch_core.ipynb 158
def find_bs(b):
    "Recursively search the batch size of `b`."
    leaf = item_find(b)
    # Leaves with a `shape` report their first dimension; otherwise fall back to `len(b)`.
    if hasattr(leaf, "shape"): return leaf.shape[0]
    return len(b)
# %% ../nbs/00_torch_core.ipynb 160
def np_func(f):
    "Convert a function taking and returning numpy arrays to one taking and returning tensors"
    def _wrapped(*args, **kwargs):
        # Only positional tensor arguments are converted; keyword arguments pass through as given.
        np_args = []
        for arg in args:
            np_args.append(to_np(arg) if isinstance(arg,Tensor) else arg)
        return tensor(f(*np_args, **kwargs))
    functools.update_wrapper(_wrapped, f)
    return _wrapped
# %% ../nbs/00_torch_core.ipynb 164
class Module(nn.Module, metaclass=PrePostInitMeta):
    "Same as `nn.Module`, but no need for subclasses to call `super().__init__`"
    def __pre_init__(self, *args, **kwargs):
        # Runs the `nn.Module` initialiser on the subclass's behalf.
        super().__init__()

    def __init__(self):
        pass
# %% ../nbs/00_torch_core.ipynb 167
from torch.nn.parallel import DistributedDataParallel
# %% ../nbs/00_torch_core.ipynb 168
def get_model(model):
    "Return the model maybe wrapped inside `model`."
    parallel_wrappers = (DistributedDataParallel, nn.DataParallel)
    if isinstance(model, parallel_wrappers):
        return model.module
    return model
# %% ../nbs/00_torch_core.ipynb 169
def one_hot(x, c):
    "One-hot encode `x` with `c` classes."
    encoded = torch.zeros(c, dtype=torch.uint8)
    # Non-empty tensors index directly; everything else is turned into a list of indices via `L`.
    if isinstance(x, Tensor) and x.numel()>0:
        hot = x
    else:
        hot = list(L(x, use_list=None))
    encoded[hot] = 1.
    return encoded
# %% ../nbs/00_torch_core.ipynb 171
def one_hot_decode(x, vocab=None):
    "Positions of `x` that equal 1, mapped through `vocab` when one is given"
    hot_positions = (i for i,flag in enumerate(x) if flag==1)
    return L(vocab[i] if vocab else i for i in hot_positions)
# %% ../nbs/00_torch_core.ipynb 173
def params(m):
    "Return all parameters of `m`"
    return list(m.parameters())
# %% ../nbs/00_torch_core.ipynb 174
def trainable_params(m):
    "Return all trainable parameters of `m`"
    trainable = []
    for p in m.parameters():
        if p.requires_grad: trainable.append(p)
    return trainable
# %% ../nbs/00_torch_core.ipynb 176
norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, nn.LayerNorm)
# %% ../nbs/00_torch_core.ipynb 177
def norm_bias_params(m, with_bias=True):
    "Return all bias and BatchNorm parameters"
    # Normalisation layers contribute every parameter they own.
    if isinstance(m, norm_types):
        return L(m.parameters())
    # Otherwise collect recursively from the children ...
    collected = L(m.children()).map(norm_bias_params, with_bias=with_bias).concat()
    # ... and add this module's own bias when requested and present.
    own_bias = getattr(m, 'bias', None)
    if with_bias and own_bias is not None:
        collected.append(m.bias)
    return collected
# %% ../nbs/00_torch_core.ipynb 179
def batch_to_samples(b, max_n=10):
    "'Transposes' a batch to (at most `max_n`) samples"
    # A tensor is split along its first axis, keeping at most `max_n` items.
    if isinstance(b, Tensor):
        return retain_types(list(b[:max_n]), [b])
    # Anything else: convert each element the same way, then zip the per-element samples together.
    per_element = L(b).map(partial(batch_to_samples,max_n=max_n))
    return retain_types(per_element.zip(), [b])
# %% ../nbs/00_torch_core.ipynb 181
@patch
def interp_1d(x:Tensor, xp, fp):
    "Same as `np.interp`"
    # Slope and intercept of each segment between consecutive `(xp, fp)` points.
    seg_slopes = (fp[1:]-fp[:-1])/(xp[1:]-xp[:-1])
    seg_intercepts = fp[:-1] - (seg_slopes*xp[:-1])
    # Segment index per query: the number of `xp` values <= x, minus one, clamped to the valid range.
    n_below = (x[:,None]>=xp[None,:]).long().sum(1)
    seg = (n_below-1).clamp(0,len(seg_slopes)-1)
    return seg_slopes[seg]*x + seg_intercepts[seg]
# %% ../nbs/00_torch_core.ipynb 183
@patch
def pca(x:Tensor, k=2):
    "Compute PCA of `x` with `k` dimensions."
    # Centre along dim 0, then project onto the first `k` left singular vectors of the transpose.
    x = x-torch.mean(x,0)
    left_vecs = torch.svd(x.t())[0]
    return torch.mm(x,left_vecs[:,:k])
# %% ../nbs/00_torch_core.ipynb 184
def logit(x):
    "Logit of `x`, clamped to avoid inf."
    eps = 1e-7
    clamped = x.clamp(eps, 1-eps)
    # -log(1/p - 1) == log(p/(1-p))
    inv_odds = 1/clamped-1
    return -inv_odds.log()
# %% ../nbs/00_torch_core.ipynb 185
def num_distrib():
    "Return the number of processes in distributed training (if applicable)."
    # Taken from the `WORLD_SIZE` environment variable; 0 when it is not set.
    world_size = os.getenv('WORLD_SIZE', 0)
    return int(world_size)
# %% ../nbs/00_torch_core.ipynb 186
def rank_distrib():
    "Return the distributed rank of this process (if applicable)."
    # Taken from the `RANK` environment variable; 0 when it is not set.
    rank = os.getenv('RANK', 0)
    return int(rank)
# %% ../nbs/00_torch_core.ipynb 187
def distrib_barrier():
    "Place a synchronization barrier in distributed training"
    # No-op unless there is more than one process and the process group is initialised.
    if num_distrib() <= 1: return
    if not torch.distributed.is_initialized(): return
    torch.distributed.barrier()
# %% ../nbs/00_torch_core.ipynb 189
# Saving arrays requires pytables - optional dependency
try: import tables
except: pass
# %% ../nbs/00_torch_core.ipynb 190
def _comp_filter(lib='lz4',lvl=3):
    "Build a `tables.Filters` using the blosc compressor `lib` at compression level `lvl`"
    complib = f'blosc:{lib}'
    return tables.Filters(complib=complib, complevel=lvl)
# %% ../nbs/00_torch_core.ipynb 191
@patch
def save_array(p:Path, o, complib='lz4', lvl=3):
    "Save numpy array to a compressed `pytables` file, using compression level `lvl`"
    # Tensors are converted to numpy first.
    if isinstance(o,Tensor):
        o = to_np(o)
    filters = _comp_filter(lib=complib,lvl=lvl)
    # The array is stored as node `data` under the file root.
    with tables.open_file(p, mode='w', filters=filters) as f:
        f.create_carray('/', 'data', obj=o)
# %% ../nbs/00_torch_core.ipynb 193
@patch
def load_array(p:Path):
    "Load numpy array from a `pytables` file"
    # Reads the `data` node under the file root.
    with tables.open_file(p, 'r') as f: return f.root.data.read()
# %% ../nbs/00_torch_core.ipynb 194
def base_doc(elt):
    "Print a base documentation of `elt`"
    # Prefer the qualified name, then the plain name, then an empty string.
    fallback_name = getattr(elt, '__name__', '')
    name = getattr(elt, '__qualname__', fallback_name)
    signature = inspect.signature(elt)
    docstring = inspect.getdoc(elt)
    print(f'{name}{signature}\n{docstring}\n')
    print('To get a prettier result with hyperlinks to source code and documentation, install nbdev: pip install nbdev')
# %% ../nbs/00_torch_core.ipynb 195
def doc(elt):
    "Try to use doc from nbdev and fall back to `base_doc`"
    try:
        # Imported here so nbdev stays optional; aliased so this function's own name is not shadowed.
        from nbdev.showdoc import doc as nbdev_doc
        nbdev_doc(elt)
    # Any failure (nbdev missing or erroring) falls back to the plain version, but
    # KeyboardInterrupt/SystemExit are no longer swallowed as with a bare `except`.
    except Exception: base_doc(elt)
# %% ../nbs/00_torch_core.ipynb 196
def nested_reorder(t, idxs):
    "Reorder all tensors in `t` using `idxs`"
    # `None` passes through untouched.
    if t is None:
        return t
    # Tensors and `L` are indexed directly with `idxs`.
    if isinstance(t, (Tensor,L)):
        return t[idxs]
    # Other listy containers are rebuilt with each element reordered.
    if is_listy(t):
        return type(t)(nested_reorder(t_, idxs) for t_ in t)
    raise TypeError(f"Expected tensor, tuple, list or L but got {type(t)}")
# %% ../nbs/00_torch_core.ipynb 198
def flatten_check(inp, targ):
    "Check that `inp` and `targ` have the same number of elements and flatten them."
    def _flat(t):
        # Made contiguous, wrapped as `TensorBase`, then viewed as one dimension.
        return TensorBase(t.contiguous()).view(-1)
    inp,targ = _flat(inp),_flat(targ)
    test_eq(len(inp), len(targ))
    return inp,targ
# %% ../nbs/00_torch_core.ipynb 201
def make_cross_image(bw=True):
    "Create a tensor containing a cross image, either `bw` (True) or color"
    if bw:
        # 5x5 image: middle row and middle column set to 1.
        cross = torch.zeros(5,5)
        cross[2,:] = 1.
        cross[:,2] = 1.
        return cross
    # 3x5x5 image: the horizontal bar goes in channel 0, the vertical bar in channel 1.
    cross = torch.zeros(3,5,5)
    cross[0,2,:] = 1.
    cross[1,:,2] = 1.
    return cross
# %% ../nbs/00_torch_core.ipynb 204
def show_image_batch(b, show=show_titled_image, items=9, cols=3, figsize=None, **kwargs):
    "Display batch `b` in a grid of size `items` with `cols` width"
    # Never use more columns than items.
    if items<cols:
        cols = items
    # Rows needed to fit `items` (ceiling division).
    rows = (items+cols-1) // cols
    if figsize is None:
        figsize = (cols*3, rows*3)
    fig,axs = plt.subplots(rows, cols, figsize=figsize)
    cpu_batch = to_cpu(b)
    # Each zipped item is one sample's fields followed by the axes to draw it on.
    for *sample,ax in zip(*cpu_batch, axs.flatten()):
        show(sample, ax=ax, **kwargs)
# %% ../nbs/00_torch_core.ipynb 207
def requires_grad(m):
    "Check if the first parameter of `m` requires grad or not"
    first_param = next(iter(m.parameters()), None)
    # Modules without any parameter report False.
    if first_param is None: return False
    return first_param.requires_grad
# %% ../nbs/00_torch_core.ipynb 209
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    # A falsy `func` leaves the module untouched, bias included.
    if not func:
        return m
    if hasattr(m, 'weight'):
        func(m.weight)
    # Zero the bias only when it exists and exposes `.data` (so a `None` bias is skipped).
    bias = getattr(m, 'bias', None)
    if hasattr(bias, 'data'):
        bias.data.fill_(0.)
    return m
# %% ../nbs/00_torch_core.ipynb 211
def cond_init(m, func):
    "Apply `init_default` to `m` unless it's a batchnorm module"
    # Skipped for normalisation layers and for modules whose first parameter does not require grad.
    if isinstance(m, norm_types): return
    if not requires_grad(m): return
    init_default(m, func)
# %% ../nbs/00_torch_core.ipynb 213
def apply_leaf(m, f):
    "Apply `f` to children of `m`."
    # `f` is called on `m` itself first, then on every descendant, depth-first.
    children = m.children()
    if isinstance(m, nn.Module):
        f(m)
    for child in children:
        apply_leaf(child,f)
# %% ../nbs/00_torch_core.ipynb 215
def apply_init(m, func=nn.init.kaiming_normal_):
    "Initialize all non-batchnorm layers of `m` with `func`."
    init_one = partial(cond_init, func=func)
    apply_leaf(m, init_one)
# %% ../nbs/00_torch_core.ipynb 218
def script_use_ctx(f):
    "Decorator: create jit script and pass everything in `ctx.saved_variables to `f`, after `*args`"
    scripted = torch.jit.script(f)
    def _with_saved(ctx, *args, **kwargs):
        # Saved variables are appended after the explicit positional arguments.
        saved = ctx.saved_variables
        return scripted(*args, *saved, **kwargs)
    return update_wrapper(_with_saved,f)
# %% ../nbs/00_torch_core.ipynb 219
def script_save_ctx(static, *argidx):
    "Decorator: create jit script and save args with indices `argidx` using `ctx.save_for_backward`"
    def _dec(f):
        sf = torch.jit.script(f)
        def _f(ctx, *args, **kwargs):
            if argidx:
                # Stash the selected positional arguments on `ctx`.
                save = [args[o] for o in argidx]
                ctx.save_for_backward(*save)
            else:
                # With no indices, `ctx` itself is passed as the first argument.
                # `args` is a tuple, so it must be extended with a tuple; the former
                # `[ctx]+args` raised TypeError (list + tuple) on every such call.
                args = (ctx,)+args
            return sf(*args, **kwargs)
        if static: _f = staticmethod(_f)
        return update_wrapper(_f,f)
    return _dec
# %% ../nbs/00_torch_core.ipynb 220
def script_fwd(*argidx):
    "Decorator: create static jit script and save args with indices `argidx` using `ctx.save_for_backward`"
    make_static = True
    return script_save_ctx(make_static, *argidx)
# %% ../nbs/00_torch_core.ipynb 221
def script_bwd(f):
    "Decorator: create static jit script and pass everything in `ctx.saved_variables to `f`, after `*args`"
    wrapped = script_use_ctx(f)
    return staticmethod(wrapped)
# %% ../nbs/00_torch_core.ipynb 222
def grad_module(cls):
    "Decorator: convert `cls` into an autograd function"
    # Returns an `nn.Module` class whose `forward` forwards all arguments to `cls.apply`.
    class _c(nn.Module):
        def forward(self, *args, **kwargs):
            return cls.apply(*args, **kwargs)
    return _c
# %% ../nbs/00_torch_core.ipynb 224
def ismin_torch(min_version):
    "Check if `torch.__version__` >= `min_version` using packaging.version"
    installed,required = parse(torch.__version__),parse(min_version)
    return installed >= required
# %% ../nbs/00_torch_core.ipynb 225
def notmax_torch(max_version):
    "Check if `torch.__version__` < `max_version` using packaging.version"
    installed,ceiling = parse(torch.__version__),parse(max_version)
    return installed < ceiling
# %% ../nbs/00_torch_core.ipynb 227
# PyTorch 1.13 introduced a Tensor Subclass string formatting bug
# Workaround from pending PyTorch PR: https://github.com/pytorch/pytorch/pull/82766
if ismin_torch('1.13') and notmax_torch('1.14'):
    from torch.overrides import has_torch_function_unary, handle_torch_function
    @patch
    def __format__(self:Tensor, format_spec):
        "Replacement `Tensor.__format__`, installed only for torch >= 1.13 and < 1.14"
        # Objects that override `__torch_function__` go through the regular dispatch path.
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
        # 0-dim, non-meta tensors (including subclasses) format as their python scalar.
        is_scalar_like = self.dim() == 0 and not self.is_meta and issubclass(type(self), Tensor)
        if is_scalar_like:
            return self.item().__format__(format_spec)
        return object.__format__(self, format_spec)