Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] New Callback event, before and after backward #3644

Merged
merged 5 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions fastai/callback/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from __future__ import annotations


__all__ = ['CancelStepException', 'CancelFitException', 'CancelEpochException', 'CancelTrainException',
'CancelValidException', 'CancelBatchException', 'event', 'Callback', 'TrainEvalCallback',
'GatherPredsCallback', 'FetchPredsCallback']
__all__ = ['CancelStepException', 'CancelBackwardException', 'CancelFitException', 'CancelEpochException',
'CancelTrainException', 'CancelValidException', 'CancelBatchException', 'event', 'Callback',
'TrainEvalCallback', 'GatherPredsCallback', 'FetchPredsCallback']

# Cell
#nbdev_comment from __future__ import annotations
Expand All @@ -15,13 +15,13 @@
from ..losses import BaseLoss

# Cell
#nbdev_comment _all_ = ['CancelStepException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']
#nbdev_comment _all_ = ['CancelStepException','CancelBackwardException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']

# Cell
_events = L.split('after_create before_fit before_epoch before_train before_batch after_pred after_loss \
before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \
after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
after_epoch after_cancel_fit after_fit')
before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \
after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \
after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')

mk_class('event', **_events.map_dict(),
doc="All possible events as attributes to get tab-completion and typo-proofing")
Expand All @@ -30,7 +30,7 @@
#nbdev_comment _all_ = ['event']

# Cell
_inner_loop = "before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()
_inner_loop = "before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch".split()

# Cell
_ex_docs = dict(
Expand All @@ -39,6 +39,7 @@
CancelValidException="Skip the rest of the validation part of the epoch and go to `after_validate`",
CancelEpochException="Skip the rest of this epoch and go to `after_epoch`",
CancelStepException ="Skip stepping the optimizer",
CancelBackwardException="Skip the backward pass and go to `after_backward`",
CancelFitException ="Interrupts training and go to `after_fit`")

for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)
Expand All @@ -60,7 +61,7 @@ def __call__(self, event_name):
res = None
if self.run and _run:
try: res = getattr(self, event_name, noop)()
except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
except Exception as e:
e.args = [f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}']
raise
Expand Down
3 changes: 1 addition & 2 deletions fastai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ def _do_one_batch(self):
self.loss = self.loss_grad.clone()
self('after_loss')
if not self.training or not len(self.yb): return
self('before_backward')
self.loss_grad.backward()
self._with_events(self.loss_grad.backward, 'backward', CancelBackwardException)
self._with_events(self.opt.step, 'step', CancelStepException)
self.opt.zero_grad()

Expand Down
39 changes: 33 additions & 6 deletions nbs/13_callback.core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"outputs": [],
"source": [
"#|export\n",
"_all_ = ['CancelStepException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']"
"_all_ = ['CancelStepException','CancelBackwardException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']"
]
},
{
Expand Down Expand Up @@ -84,9 +84,9 @@
"source": [
"#|export\n",
"_events = L.split('after_create before_fit before_epoch before_train before_batch after_pred after_loss \\\n",
" before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \\\n",
" after_train before_validate after_cancel_validate after_validate after_cancel_epoch \\\n",
" after_epoch after_cancel_fit after_fit')\n",
" before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \\\n",
" after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \\\n",
" after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')\n",
"\n",
"mk_class('event', **_events.map_dict(),\n",
" doc=\"All possible events as attributes to get tab-completion and typo-proofing\")"
Expand Down Expand Up @@ -158,7 +158,7 @@
"outputs": [],
"source": [
"#|export\n",
"_inner_loop = \"before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch\".split()"
"_inner_loop = \"before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch\".split()"
]
},
{
Expand All @@ -174,6 +174,7 @@
" CancelValidException=\"Skip the rest of the validation part of the epoch and go to `after_validate`\",\n",
" CancelEpochException=\"Skip the rest of this epoch and go to `after_epoch`\",\n",
" CancelStepException =\"Skip stepping the optimizer\",\n",
" CancelBackwardException=\"Skip the backward pass and go to `after_backward`\",\n",
" CancelFitException =\"Interrupts training and go to `after_fit`\")\n",
"\n",
"for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)"
Expand Down Expand Up @@ -202,7 +203,7 @@
" res = None\n",
" if self.run and _run: \n",
" try: res = getattr(self, event_name, noop)()\n",
" except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise\n",
" except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise\n",
" except Exception as e:\n",
" e.args = [f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\\n\\t{e.args[0]}']\n",
" raise\n",
Expand Down Expand Up @@ -667,6 +668,32 @@
"show_doc(CancelBatchException, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"<h3 id=\"CancelBackwardException\" class=\"doc_header\"><code>class</code> <code>CancelBackwardException</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
"\n",
"> <code>CancelBackwardException</code>(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n",
"\n",
"Skip the backward pass and go to `after_backward`"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(CancelBackwardException, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
Loading