Skip to content

Commit 10d4e51

Browse files
pechyonkinsgugger
authored andcommitted
Register callback_fns in .get_preds() if they haven't been registered previously. (#2237)
* Move defaults.extra_callbacks registration to __post_init__ defaults.extra_callbacks should be registered in __post_init__ in order to allow using them with Learner without training. * Register callback_fns in get_preds If callback_fns were not registered previously in the learner, register them during get_preds(). This is useful when you want to run get_preds() right after creating the learner without training first.
1 parent fc5bf23 commit 10d4e51

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

fastai/basic_train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ class Learner():
159159
layer_groups:Collection[nn.Module]=None
160160
add_time:bool=True
161161
silent:bool=None
162+
cb_fns_registered:bool=False
162163
def __post_init__(self)->None:
163164
"Setup path,metrics, callbacks and ensure model directory exists."
164165
self.path = Path(ifnone(self.path, self.data.path))
@@ -169,6 +170,7 @@ def __post_init__(self)->None:
169170
self.callbacks = listify(self.callbacks)
170171
if self.silent is None: self.silent = defaults.silent
171172
self.callback_fns = [partial(Recorder, add_time=self.add_time, silent=self.silent)] + listify(self.callback_fns)
173+
if defaults.extra_callbacks is not None: self.callbacks += defaults.extra_callbacks
172174

173175
def init(self, init): apply_init(self.model, init)
174176

@@ -196,7 +198,7 @@ def fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,
196198
if not getattr(self, 'opt', False): self.create_opt(lr, wd)
197199
else: self.opt.lr,self.opt.wd = lr,wd
198200
callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
199-
if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
201+
self.cb_fns_registered = True
200202
fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
201203

202204
def create_opt(self, lr:Floats, wd:Floats=0.)->None:
@@ -333,6 +335,12 @@ def get_preds(self, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None,
333335
"Return predictions and targets on `ds_type` dataset."
334336
lf = self.loss_func if with_loss else None
335337
activ = ifnone(activ, _loss_func2activ(self.loss_func))
338+
if not self.cb_fns_registered:
339+
lr,wd = self.lr_range(defaults.lr),self.wd
340+
if not getattr(self, 'opt', False): self.create_opt(lr, wd)
341+
else: self.opt.lr,self.opt.wd = lr,wd
342+
self.callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(self.callbacks)
343+
self.cb_fns_registered = True
336344
return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
337345
activ=activ, loss_func=lf, n_batch=n_batch, pbar=pbar)
338346

0 commit comments

Comments
 (0)