Skip to content

Commit

Permalink
Register callback_fns in .get_preds() if they haven't been regist…
Browse files Browse the repository at this point in the history
…ered previously. (#2237)

* Move defaults.extra_callbacks registration to __post_init__

defaults.extra_callbacks should be registered in __post_init__
in order to allow using them with Learner without training.

* Register callback_fns in get_preds

If callback_fns were not registered previously in the learner,
register them during get_preds(). This is useful when you want
to run get_preds() right after creating the learner without
training first.
  • Loading branch information
pechyonkin authored and sgugger committed Jul 22, 2019
1 parent fc5bf23 commit 10d4e51
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion fastai/basic_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class Learner():
layer_groups:Collection[nn.Module]=None
add_time:bool=True
silent:bool=None
cb_fns_registered:bool=False
def __post_init__(self)->None:
"Setup path,metrics, callbacks and ensure model directory exists."
self.path = Path(ifnone(self.path, self.data.path))
Expand All @@ -169,6 +170,7 @@ def __post_init__(self)->None:
self.callbacks = listify(self.callbacks)
if self.silent is None: self.silent = defaults.silent
self.callback_fns = [partial(Recorder, add_time=self.add_time, silent=self.silent)] + listify(self.callback_fns)
if defaults.extra_callbacks is not None: self.callbacks += defaults.extra_callbacks

def init(self, init): apply_init(self.model, init)

Expand Down Expand Up @@ -196,7 +198,7 @@ def fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,
if not getattr(self, 'opt', False): self.create_opt(lr, wd)
else: self.opt.lr,self.opt.wd = lr,wd
callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
self.cb_fns_registered = True
fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)

def create_opt(self, lr:Floats, wd:Floats=0.)->None:
Expand Down Expand Up @@ -333,6 +335,12 @@ def get_preds(self, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None,
"Return predictions and targets on `ds_type` dataset."
lf = self.loss_func if with_loss else None
activ = ifnone(activ, _loss_func2activ(self.loss_func))
if not self.cb_fns_registered:
lr,wd = self.lr_range(defaults.lr),self.wd
if not getattr(self, 'opt', False): self.create_opt(lr, wd)
else: self.opt.lr,self.opt.wd = lr,wd
self.callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(self.callbacks)
self.cb_fns_registered = True
return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
activ=activ, loss_func=lf, n_batch=n_batch, pbar=pbar)

Expand Down

0 comments on commit 10d4e51

Please sign in to comment.