# learner.py — forked from fastai/fastai (173 lines, 132 loc, 6.83 KB)
from .imports import *
from .torch_imports import *
from .core import *
from .transforms import *
from .model import *
from .dataset import *
from .sgdr import *
from .layer_optimizer import *
from .layers import *
from .metrics import *
from .losses import *
import time
class BasicModel():
    """Minimal wrapper pairing an underlying torch model with a display name."""

    def __init__(self, model, name='unnamed'):
        self.model = model
        self.name = name

    def get_layer_groups(self, do_fc=False):
        """Return the model's immediate child modules as the layer groups."""
        return children(self.model)
class SingleModel(BasicModel):
    """A model treated as one single layer group (no discriminative freezing)."""

    def get_layer_groups(self):
        # The whole model is the only group.
        return [self.model]
class Learner():
    """Combines a data object with a wrapped model and drives training.

    Handles weight saving/loading, layer-group (un)freezing, LR scheduling
    via callbacks, the LR finder, and prediction (including TTA).

    Args:
        data: ModelData-like object holding train/val/test dataloaders
            and a `path` attribute used for temp/model directories.
        models: a BasicModel-style wrapper exposing `.model` and
            `.get_layer_groups()`.
        opt_fn: optimizer factory; defaults to SGD with momentum 0.9.
        tmp_name: subdirectory (under `data.path`) for temporary files.
        models_name: subdirectory (under `data.path`) for saved weights.
        metrics: list of metric functions reported during training.
    """
    def __init__(self, data, models, opt_fn=None, tmp_name='tmp', models_name='models', metrics=None):
        self.data_,self.models,self.metrics = data,models,metrics
        self.sched=None
        self.clip = None
        self.opt_fn = opt_fn or SGD_Momentum(0.9)
        self.tmp_path = os.path.join(self.data.path, tmp_name)
        self.models_path = os.path.join(self.data.path, models_name)
        os.makedirs(self.tmp_path, exist_ok=True)
        os.makedirs(self.models_path, exist_ok=True)
        # FIX: original tuple-assigned self.crit twice
        # (`self.crit,self.reg_fn,self.crit = None,None,None`).
        self.crit,self.reg_fn = None,None

    def __getitem__(self,i): return self.children[i]

    @property
    def children(self): return children(self.model)

    @property
    def model(self):
        # The raw torch model lives inside the models wrapper.
        return self.models.model

    @property
    def data(self): return self.data_

    def summary(self): return model_summary(self.model, [3,self.data.sz,self.data.sz])

    def __repr__(self): return self.model.__repr__()

    def set_bn_freeze(self, m, do_freeze):
        # Mark batchnorm-like modules (detected via running_mean) as frozen.
        if hasattr(m, 'running_mean'): m.bn_freeze = do_freeze

    def bn_freeze(self, do_freeze):
        """Freeze/unfreeze batchnorm statistics across the whole model."""
        apply_leaf(self.model, lambda m: self.set_bn_freeze(m, do_freeze))

    def freeze_to(self, n):
        """Make only layer groups from index `n` onward trainable."""
        c=self.get_layer_groups()
        for l in c: set_trainable(l, False)
        for l in c[n:]: set_trainable(l, True)

    def unfreeze(self): self.freeze_to(0)

    def get_model_path(self, name):
        """Return the full path (with .h5 suffix) for a named checkpoint."""
        return os.path.join(self.models_path,name)+'.h5'

    def save(self, name): save_model(self.model, self.get_model_path(name))
    def load(self, name): load_model(self.model, self.get_model_path(name))
    def set_data(self, data): self.data_ = data

    def get_cycle_end(self, name):
        """Return a callback saving weights at each cycle end, or None."""
        if name is None: return None
        return lambda sched, cycle: self.save_cycle(name, cycle)

    def save_cycle(self, name, cycle): self.save(f'{name}_cyc_{cycle}')
    def load_cycle(self, name, cycle): self.load(f'{name}_cyc_{cycle}')

    def fit_gen(self, model, data, layer_opt, n_cycle, cycle_len=None, cycle_mult=1, cycle_save_name=None,
                metrics=None, callbacks=None, **kwargs):
        """Run the training loop for `n_cycle` cycles.

        With `cycle_len` set, uses cosine-annealed SGDR restarts
        (optionally growing by `cycle_mult` and saving at each cycle end);
        otherwise just records losses.
        """
        if callbacks is None: callbacks=[]
        if metrics is None: metrics=self.metrics
        if cycle_len:
            cycle_end = self.get_cycle_end(cycle_save_name)
            cycle_batches = len(data.trn_dl)*cycle_len
            self.sched = CosAnneal(layer_opt, cycle_batches, on_cycle_end=cycle_end, cycle_mult=cycle_mult)
        elif not self.sched: self.sched=LossRecorder(layer_opt)
        # FIX: was `callbacks+=[self.sched]`, which mutated a caller-supplied
        # list in place; build a fresh list instead.
        callbacks = callbacks + [self.sched]
        for cb in callbacks: cb.on_train_begin()
        n_epoch = sum_geom(cycle_len if cycle_len else 1, cycle_mult, n_cycle)
        fit(model, data, n_epoch, layer_opt.opt, self.crit,
            metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)

    def get_layer_groups(self): return self.models.get_layer_groups()

    def get_layer_opt(self, lrs, wds):
        """Build a LayerOptimizer with per-group lrs/weight decays."""
        return LayerOptimizer(self.opt_fn, self.get_layer_groups(), lrs, wds)

    def fit(self, lrs, n_cycle, wds=None, **kwargs):
        """Train with fresh scheduler state; see fit_gen for kwargs."""
        self.sched = None
        layer_opt = self.get_layer_opt(lrs, wds)
        self.fit_gen(self.model, self.data, layer_opt, n_cycle, **kwargs)

    def lr_find(self, start_lr=1e-5, end_lr=10, wds=None):
        """Helps you find an optimal learning rate for a model.

        It uses the technique developed in the 2015 paper
        `Cyclical Learning Rates for Training Neural Networks`, where
        we simply keep increasing the learning rate from a very small value,
        until the loss starts decreasing.

        Args:
            start_lr (float/numpy array) : Passing in a numpy array allows you
                to specify learning rates for a learner's layer_groups
            end_lr (float) : The maximum learning rate to try.
            wds (iterable/float)

        Examples:
            As training moves us closer to the optimal weights for a model,
            the optimal learning rate will be smaller. We can take advantage of
            that knowledge and provide lr_find() with a starting learning rate
            1000x smaller than the model's current learning rate as such:

            >> learn.lr_find(lr/1000)

            >> lrs = np.array([ 1e-4, 1e-3, 1e-2 ])
            >> learn.lr_find(lrs / 1000)

        Notes:
            lr_find() may finish before going through each batch of examples if
            the loss decreases enough.

        .. _Cyclical Learning Rates for Training Neural Networks:
            http://arxiv.org/abs/1506.01186
        """
        # Snapshot weights so the sweep doesn't disturb the model.
        self.save('tmp')
        layer_opt = self.get_layer_opt(start_lr, wds)
        self.sched = LR_Finder(layer_opt, len(self.data.trn_dl), end_lr)
        self.fit_gen(self.model, self.data, layer_opt, 1)
        self.load('tmp')

    def predict(self, is_test=False): return self.predict_with_targs(is_test)[0]

    def predict_with_targs(self, is_test=False):
        """Return (predictions, targets) over the test or validation set."""
        dl = self.data.test_dl if is_test else self.data.val_dl
        return predict_with_targs(self.model, dl)

    def predict_dl(self, dl): return predict_with_targs(self.model, dl)[0]

    def predict_array(self, arr):
        # NOTE(review): unconditionally moves input to CUDA — requires a GPU.
        return to_np(self.model(V(T(arr).cuda())))

    def TTA(self, n_aug=4, is_test=False):
        """ Predict with Test Time Augmentation (TTA)

        Additional to the original test/validation images, apply image augmentation to them
        (just like for training images) and calculate the mean of predictions. The intent
        is to increase the accuracy of predictions by examining the images using multiple
        perspectives.

        Args:
            n_aug: a number of augmentation images to use per original image
            is_test: indicate to use test images; otherwise use validation images

        Returns:
            (tuple): a tuple containing:

                log predictions (numpy.ndarray): log predictions (i.e. `np.exp(log_preds)` will return probabilities)
                targs (numpy.ndarray): target values when `is_test==False`; zeros otherwise.
        """
        dl1 = self.data.test_dl if is_test else self.data.val_dl
        dl2 = self.data.test_aug_dl if is_test else self.data.aug_dl
        preds1,targs = predict_with_targs(self.model, dl1)
        # Weight the un-augmented prediction roughly 1:4 against augmented ones.
        preds1 = [preds1]*math.ceil(n_aug/4)
        preds2 = [predict_with_targs(self.model, dl2)[0] for i in tqdm(range(n_aug), leave=False)]
        return np.stack(preds1+preds2).mean(0), targs