generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 274
/
transform.py
237 lines (200 loc) · 9.5 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""Definition of `Transform` and `Pipeline`"""
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/05_transform.ipynb.
# %% auto 0
# Public names exported by a star-import of this module.
__all__ = ['Sig', 'Transform', 'InplaceTransform', 'DisplayedTransform', 'ItemTransform', 'get_func', 'Func', 'compose_tfms',
'mk_transform', 'gather_attrs', 'gather_attr_names', 'Pipeline']
# %% ../nbs/05_transform.ipynb
from .imports import *
from .foundation import *
from .utils import *
from .dispatch import *
import inspect
# %% ../nbs/05_transform.ipynb
# Method names that are collected into type dispatchers instead of being stored as plain methods.
_tfm_methods = ('encodes', 'decodes', 'setups')

def _is_tfm_method(n, f):
    "Whether `n` is one of `_tfm_methods` and `f` is callable"
    if n not in _tfm_methods: return False
    return callable(f)
class _TfmDict(dict):
    "A `dict` that gathers every callable stored under a `_tfm_methods` name into one `TypeDispatch`"
    def __setitem__(self, k, v):
        if _is_tfm_method(k, v):
            # First definition under this name: create the dispatcher that will hold all of them.
            if k not in self: super().__setitem__(k, TypeDispatch())
            self[k].add(v)
        else:
            # Everything else is stored exactly as a normal dict would.
            super().__setitem__(k, v)
# %% ../nbs/05_transform.ipynb
class _TfmMeta(type):
    "Metaclass that wires up the `encodes`/`decodes`/`setups` dispatchers of a class"
    @classmethod
    def __prepare__(cls, name, bases):
        # The class body is evaluated inside a `_TfmDict` namespace.
        return _TfmDict()

    def __new__(cls, name, bases, dict):
        new_cls = super().__new__(cls, name, bases, dict)
        for meth in _tfm_methods:
            # Whatever each direct base exposes under this name (None when it has nothing).
            inherited = [getattr(base, meth, None) for base in bases]
            if meth in new_cls.__dict__:
                # Defined in this class body: attach the bases' entries to the existing dispatcher.
                getattr(new_cls, meth).bases = inherited
            else:
                # Not defined here: give the class a fresh dispatcher that only carries the bases.
                setattr(new_cls, meth, TypeDispatch(bases=inherited))
        # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
        new_cls.__signature__ = inspect.signature(new_cls.__init__)
        return new_cls

    def __call__(cls, *args, **kwargs):
        candidate = first(args)
        cand_name = getattr(candidate, '__name__', None)
        # A callable named like a transform method is registered on the class and handed back
        # unchanged; no instance is created in that case.
        if _is_tfm_method(cand_name, candidate):
            getattr(cls, cand_name).add(candidate)
            return candidate
        inst = super().__call__(*args, **kwargs)
        # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
        # instances of cls, fix it
        if hasattr(inst, '__call__'):
            inst.__signature__ = inspect.signature(inst.__call__)
        return inst
# %% ../nbs/05_transform.ipynb
def _get_name(o):
    "Best available name for `o`: its `__qualname__`, else its `__name__`, else its class name"
    for attr in ('__qualname__', '__name__'):
        if hasattr(o, attr): return getattr(o, attr)
    return o.__class__.__name__
# %% ../nbs/05_transform.ipynb
def _is_tuple(o):
    "True for tuples that have no `_fields` attribute (so namedtuple-style objects are excluded)"
    if not isinstance(o, tuple): return False
    return not hasattr(o, '_fields')
# %% ../nbs/05_transform.ipynb
class Transform(metaclass=_TfmMeta):
    "Route (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) when `split_idx` matches"
    # Class-level defaults; `__init__` may override `split_idx` and `order` per instance.
    split_idx = None
    init_enc = None
    order = 0
    train_setup = None

    def __init__(self, enc=None, dec=None, split_idx=None, order=None):
        self.split_idx = ifnone(split_idx, self.split_idx)
        if order is not None: self.order = order
        self.init_enc = enc or dec
        # Nothing to wrap: keep using the dispatchers defined on the class.
        if not self.init_enc: return
        # Wrapping plain functions: this instance gets its own, empty dispatchers.
        self.encodes = TypeDispatch()
        self.decodes = TypeDispatch()
        self.setups = TypeDispatch()
        if enc:
            self.encodes.add(enc)
            # An `order` attribute on `enc` replaces whatever order was chosen so far.
            self.order = getattr(enc, 'order', self.order)
            hints = type_hints(enc)
            if len(hints) > 0:
                self.input_types = union2tuple(first(hints.values()))
            self._name = _get_name(enc)
        if dec: self.decodes.add(dec)

    @property
    def name(self):
        # `_name` is only set when an `enc` function was wrapped; otherwise fall back to `_get_name`.
        fallback = _get_name(self)
        return getattr(self, '_name', fallback)

    def __call__(self, x, **kwargs):
        return self._call('encodes', x, **kwargs)

    # The docstring of `decode` is attached by the `add_docs` call below the class.
    def decode (self, x, **kwargs):
        return self._call('decodes', x, **kwargs)

    def __repr__(self):
        return f'{self.name}:\nencodes: {self.encodes}decodes: {self.decodes}'

    # The docstring of `setup` is attached by the `add_docs` call below the class.
    def setup(self, items=None, train_setup=False):
        # A non-None `self.train_setup` takes precedence over the argument.
        if self.train_setup is not None: train_setup = self.train_setup
        # When train setup is requested, use `items.train` if it exists.
        if train_setup: items = getattr(items, 'train', items)
        return self.setups(items)

    def _call(self, fn, x, split_idx=None, **kwargs):
        # A transform tied to a specific split leaves items of any other split untouched.
        if split_idx!=self.split_idx and self.split_idx is not None:
            return x
        dispatcher = getattr(self, fn)
        return self._do_call(dispatcher, x, **kwargs)

    def _do_call(self, f, x, **kwargs):
        if _is_tuple(x):
            # Plain tuples are handled element by element, then rebuilt with the type of `x`.
            parts = tuple(self._do_call(f, part, **kwargs) for part in x)
            return retain_type(parts, x)
        if f is None: return x
        ret = f.returns(x) if hasattr(f, 'returns') else None
        return retain_type(f(x, **kwargs), x, ret)
add_docs(Transform, decode="Delegate to <code>decodes</code> to undo transform", setup="Delegate to <code>setups</code> to set up transform")
# %% ../nbs/05_transform.ipynb
class InplaceTransform(Transform):
    "A `Transform` whose work happens in-place: the input is always what gets returned"
    def _call(self, fn, x, split_idx=None, **kwargs):
        # Run the normal call for its effect on `x`; its return value is deliberately dropped.
        super()._call(fn, x, split_idx, **kwargs)
        return x
# %% ../nbs/05_transform.ipynb
class DisplayedTransform(Transform):
    "A transform whose `name` (and so its `__repr__`) also shows its stored args"
    @property
    def name(self):
        base_name = super().name
        # Falls back to an empty dict when the instance has no `__stored_args__`.
        stored = getattr(self, '__stored_args__', {})
        return f"{base_name} -- {stored}"
# %% ../nbs/05_transform.ipynb
class ItemTransform(Transform):
    "A transform that is applied to a tuple as one item rather than to each of its elements"
    _retain = True

    def __call__(self, x, **kwargs):
        return self._call1(x, '__call__', **kwargs)

    def decode(self, x, **kwargs):
        return self._call1(x, 'decode', **kwargs)

    def _call1(self, x, name, **kwargs):
        parent_fn = getattr(super(), name)
        if not _is_tuple(x): return parent_fn(x, **kwargs)
        # Passing a list keeps `Transform._do_call` from splitting the tuple into elements.
        out = parent_fn(list(x), **kwargs)
        if not self._retain: return out
        # List-like results become tuples again before the type of `x` is restored.
        if is_listy(out) and not isinstance(out, tuple):
            out = tuple(out)
        return retain_type(out, x)
# %% ../nbs/05_transform.ipynb
def get_func(t, name, *args, **kwargs):
    "Look up `t.name` with `nested_callable` (`noop` if not defined), partial-ized when `args` or `kwargs` are given"
    fn = nested_callable(t, name)
    if args or kwargs: return partial(fn, *args, **kwargs)
    return fn
# %% ../nbs/05_transform.ipynb
class Func():
    "Remembers a function `name` plus `args` and `kwargs`, to be resolved later against a given type"
    def __init__(self, name, *args, **kwargs):
        self.name = name
        self.args = args
        self.kwargs = kwargs

    def __repr__(self):
        return f'sig: {self.name}({self.args}, {self.kwargs})'

    def _get(self, t):
        # Resolve the stored name on `t`, binding the stored arguments.
        return get_func(t, self.name, *self.args, **self.kwargs)

    def __call__(self,t):
        # `mapped` applies `_get` to `t`.
        return mapped(self._get, t)
# %% ../nbs/05_transform.ipynb
class _Sig():
    "Attribute-access factory: `obj.some_name(*args, **kwargs)` builds `Func('some_name', *args, **kwargs)`"
    def __getattr__(self,k):
        # Only reached for attributes not found normally; `k` is captured as the `Func` name.
        def _inner(*args, **kwargs):
            return Func(k, *args, **kwargs)
        return _inner

# Shared module-level instance.
Sig = _Sig()
# %% ../nbs/05_transform.ipynb
def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
    """Apply every transform in `tfms` to `x` in turn, maybe in `reverse` order.

    `x` is the value threaded through the chain and `tfms` is any iterable of transforms.
    With `is_enc=True` each transform is called directly; otherwise its `decode` attribute
    is called instead. `kwargs` are forwarded unchanged to every call.
    Returns the value produced by the last transform, or `x` itself when `tfms` is empty.
    """
    if reverse:
        try:
            tfms = reversed(tfms)
        except TypeError:
            # One-shot iterables (e.g. generators) are not reversible: materialise them first.
            tfms = reversed(list(tfms))
    for f in tfms:
        if not is_enc: f = f.decode
        x = f(x, **kwargs)
    return x
# %% ../nbs/05_transform.ipynb
def mk_transform(f):
    "Return `f` (after `instantiate`) as is when it is a `Transform` or `Pipeline`, else wrapped in a `Transform`"
    tfm = instantiate(f)
    if isinstance(tfm, (Transform, Pipeline)): return tfm
    return Transform(tfm)
# %% ../nbs/05_transform.ipynb
def gather_attrs(o, k, nm):
    "Helper for `__getattr__`: collect the non-None values of attr `k` across the members of `o.{nm}`"
    # Private names, and the container attribute itself, are never delegated.
    if k == nm or k.startswith('_'): raise AttributeError(k)
    members = getattr(o, nm)
    found = [v for v in members.attrgot(k) if v is not None]
    if not found: raise AttributeError(k)
    # A lone match is returned bare; several matches come back together as an `L`.
    if len(found) == 1: return found[0]
    return L(found)
# %% ../nbs/05_transform.ipynb
def gather_attr_names(o, nm):
    "Helper for `__dir__`: the unique attribute names found by calling `dir` on each member of `o.{nm}`"
    members = L(getattr(o, nm))
    names_per_member = members.map(dir)
    return names_per_member.concat().unique()
# %% ../nbs/05_transform.ipynb
class Pipeline:
    "An ordered chain of transforms that can be applied (encode) and undone (decode) as a whole"
    def __init__(self, funcs=None, split_idx=None):
        self.split_idx = split_idx
        self.default = None
        if funcs is None: funcs = []
        if isinstance(funcs, Pipeline):
            # NOTE(review): `fs` is shared with `funcs`, not copied, so in-place changes to it
            # (e.g. the `clear()` in `setup`) would be visible through `funcs` too -- confirm intended.
            self.fs = funcs.fs
        else:
            if isinstance(funcs, Transform): funcs = [funcs]
            # `funcs` can no longer be None here, so the `[noop]` fallback is never taken.
            self.fs = L(ifnone(funcs,[noop])).map(mk_transform).sorted(key='order')
        # Expose each transform as an attribute named after its class, in snake_case.
        for tfm in self.fs:
            attr = camel2snake(type(tfm).__name__)
            existing = getattr(self, attr, None)
            # When the name already resolves to something, combine both as `L(existing)+tfm`.
            value = tfm if existing is None else L(existing) + tfm
            setattr(self, attr, value)

    def setup(self, items=None, train_setup=False):
        # Empty the pipeline, then re-add the transforms one at a time via `add`.
        pending = self.fs[:]
        self.fs.clear()
        for tfm in pending:
            self.add(tfm, items, train_setup)

    def add(self,ts, items=None, train_setup=False):
        if not is_listy(ts): ts = [ts]
        for tfm in ts:
            tfm.setup(items, train_setup)
        self.fs+=ts
        # Keep the transforms sorted by their `order`.
        self.fs = self.fs.sorted(key='order')

    def __call__(self, o):
        return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)

    def __repr__(self):
        shown = [f.name for f in self.fs if f.name != 'noop']
        return f"Pipeline: {' -> '.join(shown)}"

    def __getitem__(self,i):
        return self.fs[i]

    def __setstate__(self,data):
        self.__dict__.update(data)

    def __getattr__(self,k):
        # Attributes not found on the pipeline are looked up on its transforms.
        return gather_attrs(self, k, 'fs')

    def __dir__(self):
        return super().__dir__() + gather_attr_names(self, 'fs')

    def decode (self, o, full=True):
        if full:
            return compose_tfms(o, tfms=self.fs, is_enc=False, reverse=True, split_idx=self.split_idx)
        #Not full means we decode up to the point the item knows how to show itself.
        for tfm in reversed(self.fs):
            if self._is_showable(o): return o
            o = tfm.decode(o, split_idx=self.split_idx)
        return o

    def show(self, o, ctx=None, **kwargs):
        decoded = self.decode(o, full=False)
        # An item that can show itself is shown as a whole.
        if hasattr(decoded, 'show'):
            return decoded.show(ctx=ctx, **kwargs)
        # Otherwise each showable element is shown in turn, threading `ctx` through.
        parts = decoded if _is_tuple(decoded) else (decoded,)
        for part in parts:
            if hasattr(part, 'show'):
                ctx = part.show(ctx=ctx, **kwargs)
        return ctx

    def _is_showable(self, o):
        # Showable: has `show` itself, or is a plain tuple whose elements all have `show`.
        return hasattr(o, 'show') or (_is_tuple(o) and all(hasattr(part, 'show') for part in o))