# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_learner.ipynb.
# %% auto 0
__all__ = ['device', 'CancelFitException', 'CancelEpochException', 'CancelBatchException', 'PublishEvents', 'Learner',
'Subscriber', 'MetricsS', 'DeviceS', 'LRFindS', 'MomentumLearner', 'ProgressS']
# %% ../nbs/02_learner.ipynb 2
import math
from functools import wraps
from operator import attrgetter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import fastcore.all as fc
import matplotlib.pyplot as plt
import torcheval.metrics as tem
from fastprogress import progress_bar,master_bar
# %% ../nbs/02_learner.ipynb 3
# Control-flow exceptions: a subscriber raises one of these to abort the
# current fit / epoch / batch; PublishEvents catches the matching one.
class CancelFitException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass
class PublishEvents():
    """Decorator factory for Learner lifecycle methods.

    Wrapping a method with `@PublishEvents(name)` publishes `before_{name}`
    and `after_{name}` events around the call, and silently absorbs the
    matching `Cancel{Name}Exception` so subscribers can abort that stage.
    """
    def __init__(self, name):
        self.name = name
    def __call__(self, decorated_fn):
        # functools.wraps preserves the wrapped method's __name__/__doc__
        # (the original decorator hid them behind the wrapper's).
        @wraps(decorated_fn)
        def decorated_fn_with_publishing(learner, *args, **kwargs):
            try:
                learner.publish(f'before_{self.name}')
                decorated_fn(learner, *args, **kwargs)
                learner.publish(f'after_{self.name}')
            # Resolve e.g. CancelFitException by name at raise time; the
            # exception classes live at module level.
            except globals()[f'Cancel{self.name.title()}Exception']: pass
        return decorated_fn_with_publishing
# %% ../nbs/02_learner.ipynb 4
class Learner():
    """Minimal event-driven training loop.

    Publishes lifecycle events (`before_fit`, `after_predict`, ...) to the
    registered subscribers, which customize/observe training via callbacks.

    Args:
        model: torch module to train.
        dls: object exposing `.train` and `.valid` dataloaders.
        loss_fn: callable mapping (preds, targets) -> scalar loss.
        optim_class: optimizer constructor, e.g. `torch.optim.SGD`.
        lr: default learning rate used when `fit` is not given one.
        subs: list of subscriber objects receiving published events.
    """
    def __init__(self, model, dls, loss_fn, optim_class, lr, subs):
        self.model = model
        self.dls = dls
        self.loss_fn = loss_fn
        self.optim_class = optim_class
        self.lr = lr
        self.subs = subs
    def fit(self, epochs, train=True, valid=True, subs=None, lr=None):
        """Run `epochs` epochs; `subs` are registered for this call only."""
        # None sentinel instead of a mutable [] default (shared-state pitfall).
        subs = [] if subs is None else subs
        for sub in subs: self.subs.append(sub)
        self.n_epochs = epochs
        self.epochs = range(self.n_epochs)
        lr = self.lr if lr is None else lr
        # Fresh optimizer each fit so a per-call lr takes effect.
        self.opt = self.optim_class(self.model.parameters(), lr)
        try:
            self._fit(train, valid)
        finally:
            # Deregister per-fit subscribers even if the fit was cancelled.
            for sub in subs: self.subs.remove(sub)
    @PublishEvents('fit')
    def _fit(self, train, valid):
        for self.epoch in self.epochs:
            if train:
                self.one_epoch(True)
            if valid:
                with torch.no_grad():  # no autograd during evaluation
                    self.one_epoch(False)
    def one_epoch(self, train):
        """Run one epoch in train (True) or eval (False) mode."""
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        self._one_epoch()
    @PublishEvents('epoch')
    def _one_epoch(self):
        for self.batch in self.dl:
            self.one_batch()
    @PublishEvents('batch')
    def one_batch(self):
        self.predict()
        self.publish('after_predict')
        self.get_loss()
        self.publish('after_loss')
        if self.model.training:
            self.backward()
            self.publish('after_backward')
            self.step()
            self.publish('after_step')
            self.zero_grad()
    def publish(self, event):
        """Invoke `event` on every subscriber that defines it, in `order`."""
        for sub in sorted(self.subs, key=attrgetter('order')):
            method = getattr(sub, event, None)
            if method is not None: method(self)
    # The methods below are the override points for Learner subclasses.
    def predict(self):
        self.preds = self.model(self.batch[0])
    def get_loss(self):
        self.loss = self.loss_fn(self.preds, self.batch[1])
    def backward(self): self.loss.backward()
    def step(self): self.opt.step()
    def zero_grad(self): self.opt.zero_grad()
# %% ../nbs/02_learner.ipynb 5
class Subscriber:
    """Base class for Learner event callbacks.

    Subscribers with a lower `order` are invoked first by `Learner.publish`.
    """
    # Default priority; subclasses override to run earlier/later.
    order = 0
# %% ../nbs/02_learner.ipynb 6
class MetricsS(Subscriber):
    """Tracks named torcheval metrics plus a batch-size-weighted mean loss,
    and logs one summary line per epoch."""
    def __init__(self, **metrics):
        self.metrics = metrics
        self.loss = tem.Mean()  # running weighted mean of the loss
    def before_fit(self, learn):
        # Expose self on the learner so other subscribers (e.g. a progress
        # bar) can redirect `output` to their own renderer.
        learn.metrics = self
    def before_epoch(self, learn):
        for m in self.metrics.values(): m.reset()
        self.loss.reset()
    def after_batch(self, learn):
        x, y, *_ = self.to_cpu(learn.batch)
        for m in self.metrics.values(): m.update(self.to_cpu(learn.preds), y)
        # Weight by batch size so the epoch loss is a true per-sample mean.
        self.loss.update(self.to_cpu(learn.loss), weight=len(x))
    def after_epoch(self, learn):
        log = {
            'epoch': learn.epoch,
            'mode': 'train' if learn.model.training else 'eval',
            'loss' : f'{self.loss.compute():.3f}'
        }
        for k, v in self.metrics.items():
            log[k] = f'{v.compute():.3f}'
        self.output(log)
    def to_cpu(self, x):
        """Recursively detach tensors (or lists/tuples of them) to CPU."""
        # Accept tuples as well as lists, and return a concrete list rather
        # than a one-shot generator so the result can be reused or len()-ed.
        if isinstance(x, (list, tuple)): return [self.to_cpu(el) for el in x]
        return x.detach().cpu()
    def output(self, log): print(log)
# %% ../nbs/02_learner.ipynb 7
# Default compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DeviceS(Subscriber):
    """Subscriber that keeps the model and every batch on a target device."""
    def __init__(self, device): self.device = device
    def before_fit(self, learn):
        # Move the parameters once, up front.
        learn.model.to(self.device)
    def before_batch(self, learn):
        # Transfer each tensor of the batch just before it is used.
        learn.batch = [t.to(self.device) for t in learn.batch]
# %% ../nbs/02_learner.ipynb 8
class LRFindS(Subscriber):
    """Learning-rate finder: grows the lr geometrically every batch while
    recording (lr, loss) pairs, and cancels the fit once the loss diverges
    past 3x the best value seen."""
    def __init__(self, mult=1.25):
        self.mult = mult
        self.min = math.inf  # lowest loss observed so far
    def before_epoch(self, learn):
        # Only the training pass matters; bail out before validation starts.
        if not learn.model.training: raise CancelFitException
        self.losses, self.lrs = [], []
    def after_loss(self, learn):
        current_lr = learn.opt.param_groups[0]['lr']
        self.lrs.append(current_lr)
        current_loss = learn.loss.detach().cpu()
        self.losses.append(current_loss)
        self.min = min(self.min, current_loss)
        # Divergence check: stop once the loss blows past 3x the best.
        if current_loss > self.min * 3: raise CancelFitException()
        # Geometric lr schedule: scale every param group up for the next batch.
        for group in learn.opt.param_groups: group['lr'] = current_lr * self.mult
    def plot(self):
        """Plot the recorded losses against learning rates (log-x axis)."""
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')
# %% ../nbs/02_learner.ipynb 9
class MomentumLearner(Learner):
    """Learner whose `zero_grad` decays gradients instead of clearing them.

    Scaling the gradients by `mom` after each step means the next
    `backward()` accumulates on top of the decayed residue, giving plain
    SGD an implicit exponential-moving-average momentum.
    """
    def __init__(self, model, dls, loss_fn, optim_class, lr, subs, mom=0.85):
        self.mom = mom
        super().__init__(model, dls, loss_fn, optim_class, lr, subs)
    def zero_grad(self):
        with torch.no_grad():
            for param in self.model.parameters(): param.grad *= self.mom
# %% ../nbs/02_learner.ipynb 10
class ProgressS(Subscriber):
    """Progress-bar subscriber: renders fastprogress bars, routes the metrics
    table through them, and optionally live-plots the loss curves."""
    # Run after MetricsS so metrics are already computed when displayed.
    order = MetricsS.order+1
    def __init__(self, plot=False): self.plot = plot
    def before_fit(self, learn):
        # Wrap the epoch range in a master bar; it is also the parent of the
        # per-epoch batch bars created in before_epoch.
        learn.epochs = self.mbar = master_bar(learn.epochs)
        self.first = True  # first output() call also writes the table header
        # Redirect the metrics logger so its rows render inside the master bar.
        if hasattr(learn, 'metrics'): learn.metrics.output = self.output
        self.losses = []
        self.val_losses = []
    def output(self, d):
        # Replacement for MetricsS.output: write the dict keys as a header
        # row once, then the values as a table row.
        if self.first:
            self.mbar.write(list(d), table=True)
            self.first = False
        self.mbar.write(list(d.values()), table=True)
    def before_epoch(self, learn): learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar)
    def after_batch(self, learn):
        # Show the current batch loss next to the progress bar.
        learn.dl.comment = f'{learn.loss:.3f}'
        if self.plot and hasattr(learn, 'metrics') and learn.model.training:
            self.losses.append(learn.loss.item())
    def after_epoch(self, learn):
        # Only after the validation pass: record the epoch's validation loss
        # and redraw the graph (training losses per batch, validation losses
        # positioned at the end of each epoch's training batches).
        if not learn.model.training:
            if self.plot and hasattr(learn, 'metrics'):
                if self.val_losses:
                    # Redraw with the previous epochs' validation points first.
                    self.mbar.update_graph([[fc.L.range(self.losses), self.losses],[fc.L.range(learn.epoch).map(lambda x: (x+1)*len(learn.dls.train)), self.val_losses]])
                self.val_losses.append(learn.metrics.loss.compute())
                self.mbar.update_graph([[fc.L.range(self.losses), self.losses],[fc.L.range(learn.epoch+1).map(lambda x: (x+1)*len(learn.dls.train)), self.val_losses]])