Commit c31cd24: fixes #3739
jph00 committed Jul 16, 2022 (1 parent: bf43072)
Showing 2 changed files with 97 additions and 28 deletions.
25 changes: 20 additions & 5 deletions fastai/learner.py
@@ -97,17 +97,32 @@ def before_epoch(self):
# Cell
class Learner(GetAttr):
    _default='model'
-    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,
-                 metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
-                 moms=(0.95,0.85,0.95)):
+    def __init__(self,
+                 dls, # `DataLoaders` containing data for each dataset needed for `model`
+                 model:callable, # The model to train or use for inference
+                 loss_func:callable|None=None, # Loss function for training
+                 opt_func=Adam, # Optimisation function for training
+                 lr=defaults.lr, # Learning rate
+                 splitter:callable=trainable_params, # Used to split parameters into layer groups
+                 cbs=None, # Callbacks
+                 metrics=None, # Printed after each epoch
+                 path=None, # Parent directory to save, load, and export models
+                 model_dir='models', # Subdirectory to save and load models
+                 wd=None, # Weight decay
+                 wd_bn_bias=False, # Apply weight decay to batchnorm bias params?
+                 train_bn=True, # Always train batchnorm layers?
+                 moms=(0.95,0.85,0.95), # Momentum
+                 default_cbs:bool=True # Include default callbacks?
+                ):
        path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))
        if loss_func is None:
            loss_func = getattr(dls.train_ds, 'loss_func', None)
            assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
        self.dls,self.model = dls,model
        store_attr(but='dls,model,cbs')
        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
-        self.add_cbs(L(defaults.callbacks)+L(cbs))
+        if default_cbs: self.add_cbs(L(defaults.callbacks))
+        self.add_cbs(cbs)
        self.lock = threading.Lock()
        self("after_create")

@@ -413,7 +428,7 @@ def load_learner(fname, cpu=True, pickle_module=pickle):
    map_loc = 'cpu' if cpu else default_device()
    try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)
    except AttributeError as e:
-        e.args = [f"Custom classes or functions exported with your `Learner` are not available in the namespace currently.\nPlease re-declare or import them before calling `load_learner`:\n\t{e.args[0]}"]
+        e.args = [f"Custom classes or functions exported with your `Learner` are not available in the namespace.\nRe-declare or import them before loading:\n\t{e.args[0]}"]
        raise
    if cpu:
        res.dls.cpu()
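The reworded `load_learner` error points at the standard remedy: any custom classes or functions that were in scope when the `Learner` was exported must be re-declared or imported before loading. A minimal sketch, with a hypothetical `my_custom_loss` and export file name:

    from fastai.learner import load_learner

    # Hypothetical: this function existed when the Learner was exported,
    # so it must be defined (or imported) again before unpickling.
    def my_custom_loss(preds, targs): return ((preds - targs)**2).mean()

    learn = load_learner('export.pkl')  # unpickling can now resolve my_custom_loss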
100 changes: 77 additions & 23 deletions nbs/13a_learner.ipynb
@@ -321,17 +321,32 @@
"#|export\n",
"class Learner(GetAttr):\n",
"    _default='model'\n",
-"    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,\n",
-"                 metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,\n",
-"                 moms=(0.95,0.85,0.95)):\n",
+"    def __init__(self,\n",
+"                 dls, # `DataLoaders` containing data for each dataset needed for `model`\n",
+"                 model:callable, # The model to train or use for inference\n",
+"                 loss_func:callable|None=None, # Loss function for training\n",
+"                 opt_func=Adam, # Optimisation function for training\n",
+"                 lr=defaults.lr, # Learning rate\n",
+"                 splitter:callable=trainable_params, # Used to split parameters into layer groups\n",
+"                 cbs=None, # Callbacks\n",
+"                 metrics=None, # Printed after each epoch\n",
+"                 path=None, # Parent directory to save, load, and export models\n",
+"                 model_dir='models', # Subdirectory to save and load models\n",
+"                 wd=None, # Weight decay\n",
+"                 wd_bn_bias=False, # Apply weight decay to batchnorm bias params?\n",
+"                 train_bn=True, # Always train batchnorm layers?\n",
+"                 moms=(0.95,0.85,0.95), # Momentum\n",
+"                 default_cbs:bool=True # Include default callbacks?\n",
+"                ):\n",
"        path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))\n",
"        if loss_func is None:\n",
"            loss_func = getattr(dls.train_ds, 'loss_func', None)\n",
"            assert loss_func is not None, \"Could not infer loss function from the data, please pass a loss function.\"\n",
"        self.dls,self.model = dls,model\n",
"        store_attr(but='dls,model,cbs')\n",
"        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()\n",
-"        self.add_cbs(L(defaults.callbacks)+L(cbs))\n",
+"        if default_cbs: self.add_cbs(L(defaults.callbacks))\n",
+"        self.add_cbs(cbs)\n",
"        self.lock = threading.Lock()\n",
"        self(\"after_create\")\n",
"\n",
@@ -596,9 +611,27 @@
"text/markdown": [
"<h2 id=\"Learner\" class=\"doc_header\"><code>class</code> <code>Learner</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h2>\n",
"\n",
"> <code>Learner</code>(**`dls`**, **`model`**, **`loss_func`**=*`None`*, **`opt_func`**=*`Adam`*, **`lr`**=*`0.001`*, **`splitter`**=*`trainable_params`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*) :: [`GetAttr`](https://fastcore.fast.ai/basics#GetAttr)\n",
"\n",
"Group together a `model`, some `dls` and a `loss_func` to handle training"
"> <code>Learner</code>(**`dls`**, **`model`**:`callable`, **`loss_func`**:`callable | None`=*`None`*, **`opt_func`**=*`Adam`*, **`lr`**=*`0.001`*, **`splitter`**:`callable`=*`trainable_params`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*, **`default_cbs`**:`bool`=*`True`*) :: [`GetAttr`](https://fastcore.fast.ai/basics#GetAttr)\n",
"\n",
"Group together a `model`, some `dls` and a `loss_func` to handle training\n",
"\n",
"||Type|Default|Details|\n",
"|---|---|---|---|\n",
"|**`dls`**|||[`DataLoaders`](/data.core.html#DataLoaders) containing data for each dataset needed for `model`|\n",
"|**`model`**|`callable`||The model to train or use for inference|\n",
"|**`loss_func`**|`callable or None`|`None`|Loss function for training|\n",
"|**`opt_func`**|`function`|[`Adam`](/optimizer.html#Adam)|Optimisation function for training|\n",
"|**`lr`**|`float`|`0.001`|Learning rate|\n",
"|**`splitter`**|`callable`|[`trainable_params`](/torch_core.html#trainable_params)|Used to split parameters into layer groups|\n",
"|**`cbs`**||`None`|Callbacks|\n",
"|**`metrics`**||`None`|Printed after each epoch|\n",
"|**`path`**||`None`|Parent directory to save, load, and export models|\n",
"|**`model_dir`**|`str`|`models`|Subdirectory to save and load models|\n",
"|**`wd`**||`None`|Weight decay|\n",
"|**`wd_bn_bias`**|`bool`|`False`|Apply weight decay to batchnorm bias params?|\n",
"|**`train_bn`**|`bool`|`True`|Always train batchnorm layers?|\n",
"|**`moms`**|`tuple`|`(0.95, 0.85, 0.95)`|Momentum|\n",
"|**`default_cbs`**|`bool`|`True`|Include default callbacks?|\n"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
@@ -681,7 +714,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.fit\" class=\"doc_header\"><code>Learner.fit</code><a href=\"__main__.py#L136\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.fit\" class=\"doc_header\"><code>Learner.fit</code><a href=\"__main__.py#L152\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.fit</code>(**`n_epoch`**, **`lr`**=*`None`*, **`wd`**=*`None`*, **`cbs`**=*`None`*, **`reset_opt`**=*`False`*, **`start_epoch`**=*`0`*)\n",
"\n",
@@ -861,7 +894,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.one_batch\" class=\"doc_header\"><code>Learner.one_batch</code><a href=\"__main__.py#L112\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.one_batch\" class=\"doc_header\"><code>Learner.one_batch</code><a href=\"__main__.py#L128\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.one_batch</code>(**`i`**, **`b`**)\n",
"\n",
@@ -1003,7 +1036,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.all_batches\" class=\"doc_header\"><code>Learner.all_batches</code><a href=\"__main__.py#L87\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.all_batches\" class=\"doc_header\"><code>Learner.all_batches</code><a href=\"__main__.py#L102\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.all_batches</code>()\n",
"\n",
@@ -1103,7 +1136,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.create_opt\" class=\"doc_header\"><code>Learner.create_opt</code><a href=\"__main__.py#L68\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.create_opt\" class=\"doc_header\"><code>Learner.create_opt</code><a href=\"__main__.py#L83\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.create_opt</code>()\n",
"\n",
@@ -1237,7 +1270,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.__call__\" class=\"doc_header\"><code>Learner.__call__</code><a href=\"__main__.py#L61\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.__call__\" class=\"doc_header\"><code>Learner.__call__</code><a href=\"__main__.py#L76\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.__call__</code>(**`event_name`**)\n",
"\n",
@@ -1289,7 +1322,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.add_cb\" class=\"doc_header\"><code>Learner.add_cb</code><a href=\"__main__.py#L33\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.add_cb\" class=\"doc_header\"><code>Learner.add_cb</code><a href=\"__main__.py#L48\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.add_cb</code>(**`cb`**)\n",
"\n",
@@ -1328,7 +1361,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.add_cbs\" class=\"doc_header\"><code>Learner.add_cbs</code><a href=\"__main__.py#L25\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.add_cbs\" class=\"doc_header\"><code>Learner.add_cbs</code><a href=\"__main__.py#L40\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.add_cbs</code>(**`cbs`**)\n",
"\n",
@@ -1364,7 +1397,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.added_cbs\" class=\"doc_header\"><code>Learner.added_cbs</code><a href=\"__main__.py#L48\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.added_cbs\" class=\"doc_header\"><code>Learner.added_cbs</code><a href=\"__main__.py#L63\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.added_cbs</code>(**`cbs`**)\n",
"\n",
@@ -1402,7 +1435,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.ordered_cbs\" class=\"doc_header\"><code>Learner.ordered_cbs</code><a href=\"__main__.py#L60\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.ordered_cbs\" class=\"doc_header\"><code>Learner.ordered_cbs</code><a href=\"__main__.py#L75\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.ordered_cbs</code>(**`event`**)\n",
"\n",
@@ -1457,7 +1490,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.remove_cb\" class=\"doc_header\"><code>Learner.remove_cb</code><a href=\"__main__.py#L40\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.remove_cb\" class=\"doc_header\"><code>Learner.remove_cb</code><a href=\"__main__.py#L55\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.remove_cb</code>(**`cb`**)\n",
"\n",
@@ -1518,7 +1551,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.remove_cbs\" class=\"doc_header\"><code>Learner.remove_cbs</code><a href=\"__main__.py#L29\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.remove_cbs\" class=\"doc_header\"><code>Learner.remove_cbs</code><a href=\"__main__.py#L44\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.remove_cbs</code>(**`cbs`**)\n",
"\n",
@@ -1564,7 +1597,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.removed_cbs\" class=\"doc_header\"><code>Learner.removed_cbs</code><a href=\"__main__.py#L54\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.removed_cbs\" class=\"doc_header\"><code>Learner.removed_cbs</code><a href=\"__main__.py#L69\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.removed_cbs</code>(**`cbs`**)\n",
"\n",
@@ -1610,7 +1643,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.show_training_loop\" class=\"doc_header\"><code>Learner.show_training_loop</code><a href=\"__main__.py#L207\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.show_training_loop\" class=\"doc_header\"><code>Learner.show_training_loop</code><a href=\"__main__.py#L223\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.show_training_loop</code>()\n",
"\n",
@@ -1908,7 +1941,7 @@
" map_loc = 'cpu' if cpu else default_device()\n",
" try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)\n",
" except AttributeError as e: \n",
" e.args = [f\"Custom classes or functions exported with your `Learner` are not available in the namespace currently.\\nPlease re-declare or import them before calling `load_learner`:\\n\\t{e.args[0]}\"]\n",
" e.args = [f\"Custom classes or functions exported with your `Learner` not available in namespace.\\Re-declare/import before loading:\\n\\t{e.args[0]}\"]\n",
" raise\n",
" if cpu: \n",
" res.dls.cpu()\n",
@@ -1939,7 +1972,7 @@
{
"data": {
"text/markdown": [
"<h4 id=\"Learner.to_detach\" class=\"doc_header\"><code>Learner.to_detach</code><a href=\"__main__.py#L224\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"<h4 id=\"Learner.to_detach\" class=\"doc_header\"><code>Learner.to_detach</code><a href=\"__main__.py#L240\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
"\n",
"> <code>Learner.to_detach</code>(**`b`**, **`cpu`**=*`True`*, **`gather`**=*`True`*)\n",
"\n",
@@ -3668,7 +3701,28 @@
"Converted 50_tutorial.datablock.ipynb.\n",
"Converted 60_medical.imaging.ipynb.\n",
"Converted 61_tutorial.medical_imaging.ipynb.\n",
"Converted 65_medical.text.ipynb.\n"
"Converted 65_medical.text.ipynb.\n",
"Converted 70_callback.wandb.ipynb.\n",
"Converted 70a_callback.tensorboard.ipynb.\n",
"Converted 70b_callback.neptune.ipynb.\n",
"Converted 70c_callback.captum.ipynb.\n",
"Converted 70d_callback.comet.ipynb.\n",
"Converted 74_huggingface.ipynb.\n",
"Converted 97_test_utils.ipynb.\n",
"Converted 99_pytorch_doc.ipynb.\n",
"Converted dev-setup.ipynb.\n",
"Converted app_examples.ipynb.\n",
"Converted camvid.ipynb.\n",
"Converted distributed_app_examples.ipynb.\n",
"Converted migrating_catalyst.ipynb.\n",
"Converted migrating_ignite.ipynb.\n",
"Converted migrating_lightning.ipynb.\n",
"Converted migrating_pytorch.ipynb.\n",
"Converted migrating_pytorch_verbose.ipynb.\n",
"Converted ulmfit.ipynb.\n",
"Converted index.ipynb.\n",
"Converted quick_start.ipynb.\n",
"Converted tutorial.ipynb.\n"
]
}
],
