Skip to content

Commit

Permalink
Merge pull request #3644 from muellerzr/callback
Browse files Browse the repository at this point in the history
[Enhancement] New Callback event, before and after backward
  • Loading branch information
jph00 committed May 11, 2022
2 parents 2106700 + 2b6a0c6 commit eed715b
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 373 deletions.
19 changes: 10 additions & 9 deletions fastai/callback/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from __future__ import annotations


__all__ = ['CancelStepException', 'CancelFitException', 'CancelEpochException', 'CancelTrainException',
'CancelValidException', 'CancelBatchException', 'event', 'Callback', 'TrainEvalCallback',
'GatherPredsCallback', 'FetchPredsCallback']
__all__ = ['CancelStepException', 'CancelBackwardException', 'CancelFitException', 'CancelEpochException',
'CancelTrainException', 'CancelValidException', 'CancelBatchException', 'event', 'Callback',
'TrainEvalCallback', 'GatherPredsCallback', 'FetchPredsCallback']

# Cell
#nbdev_comment from __future__ import annotations
Expand All @@ -15,13 +15,13 @@
from ..losses import BaseLoss

# Cell
#nbdev_comment _all_ = ['CancelStepException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']
#nbdev_comment _all_ = ['CancelStepException','CancelBackwardException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']

# Cell
_events = L.split('after_create before_fit before_epoch before_train before_batch after_pred after_loss \
before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \
after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
after_epoch after_cancel_fit after_fit')
before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \
after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \
after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')

mk_class('event', **_events.map_dict(),
doc="All possible events as attributes to get tab-completion and typo-proofing")
Expand All @@ -30,7 +30,7 @@
#nbdev_comment _all_ = ['event']

# Cell
_inner_loop = "before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()
_inner_loop = "before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch".split()

# Cell
_ex_docs = dict(
Expand All @@ -39,6 +39,7 @@
CancelValidException="Skip the rest of the validation part of the epoch and go to `after_validate`",
CancelEpochException="Skip the rest of this epoch and go to `after_epoch`",
CancelStepException ="Skip stepping the optimizer",
CancelBackwardException="Skip the backward pass and go to `after_backward`",
CancelFitException ="Interrupts training and go to `after_fit`")

for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)
Expand All @@ -60,7 +61,7 @@ def __call__(self, event_name):
res = None
if self.run and _run:
try: res = getattr(self, event_name, noop)()
except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
except Exception as e:
e.args = [f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}']
raise
Expand Down
13 changes: 6 additions & 7 deletions fastai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from __future__ import annotations


__all__ = ['CancelStepException', 'CancelFitException', 'CancelEpochException', 'CancelTrainException',
'CancelValidException', 'CancelBatchException', 'replacing_yield', 'mk_metric', 'save_model', 'load_model',
'Learner', 'before_batch_cb', 'load_learner', 'Metric', 'AvgMetric', 'AvgLoss', 'AvgSmoothLoss',
'ValueMetric', 'Recorder']
__all__ = ['CancelBackwardException', 'CancelStepException', 'CancelFitException', 'CancelEpochException',
'CancelTrainException', 'CancelValidException', 'CancelBatchException', 'replacing_yield', 'mk_metric',
'save_model', 'load_model', 'Learner', 'before_batch_cb', 'load_learner', 'Metric', 'AvgMetric', 'AvgLoss',
'AvgSmoothLoss', 'ValueMetric', 'Recorder']

# Cell
#nbdev_comment from __future__ import annotations
Expand All @@ -17,7 +17,7 @@
import pickle,threading

# Cell
#nbdev_comment _all_ = ['CancelStepException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']
#nbdev_comment _all_ = ['CancelBackwardException', 'CancelStepException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']

# Cell
defaults.lr = 1e-3
Expand Down Expand Up @@ -182,8 +182,7 @@ def _do_one_batch(self):
self.loss = self.loss_grad.clone()
self('after_loss')
if not self.training or not len(self.yb): return
self('before_backward')
self.loss_grad.backward()
self._with_events(self.loss_grad.backward, 'backward', CancelBackwardException)
self._with_events(self.opt.step, 'step', CancelStepException)
self.opt.zero_grad()

Expand Down
39 changes: 33 additions & 6 deletions nbs/13_callback.core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"outputs": [],
"source": [
"#|export\n",
"_all_ = ['CancelStepException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']"
"_all_ = ['CancelStepException','CancelBackwardException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']"
]
},
{
Expand Down Expand Up @@ -84,9 +84,9 @@
"source": [
"#|export\n",
"_events = L.split('after_create before_fit before_epoch before_train before_batch after_pred after_loss \\\n",
" before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \\\n",
" after_train before_validate after_cancel_validate after_validate after_cancel_epoch \\\n",
" after_epoch after_cancel_fit after_fit')\n",
" before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \\\n",
" after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \\\n",
" after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')\n",
"\n",
"mk_class('event', **_events.map_dict(),\n",
" doc=\"All possible events as attributes to get tab-completion and typo-proofing\")"
Expand Down Expand Up @@ -158,7 +158,7 @@
"outputs": [],
"source": [
"#|export\n",
"_inner_loop = \"before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch\".split()"
"_inner_loop = \"before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch\".split()"
]
},
{
Expand All @@ -174,6 +174,7 @@
" CancelValidException=\"Skip the rest of the validation part of the epoch and go to `after_validate`\",\n",
" CancelEpochException=\"Skip the rest of this epoch and go to `after_epoch`\",\n",
" CancelStepException =\"Skip stepping the optimizer\",\n",
" CancelBackwardException=\"Skip the backward pass and go to `after_backward`\",\n",
" CancelFitException =\"Interrupts training and go to `after_fit`\")\n",
"\n",
"for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)"
Expand Down Expand Up @@ -202,7 +203,7 @@
" res = None\n",
" if self.run and _run: \n",
" try: res = getattr(self, event_name, noop)()\n",
" except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise\n",
" except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise\n",
" except Exception as e:\n",
" e.args = [f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\\n\\t{e.args[0]}']\n",
" raise\n",
Expand Down Expand Up @@ -667,6 +668,32 @@
"show_doc(CancelBatchException, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"<h3 id=\"CancelBackwardException\" class=\"doc_header\"><code>class</code> <code>CancelBackwardException</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
"\n",
"> <code>CancelBackwardException</code>(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n",
"\n",
"Skip the backward pass and go to `after_backward`"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(CancelBackwardException, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
Loading

0 comments on commit eed715b

Please sign in to comment.