Skip to content

Commit

Permalink
fixes #2931 #2920
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Nov 3, 2020
1 parent 91df5a3 commit 12e7977
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
Untitled*.ipynb
*.bak
token
.idea/
Expand Down
2 changes: 1 addition & 1 deletion fastai/_nbdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@
"Mish": "01_layers.ipynb",
"ParameterModule": "01_layers.ipynb",
"children_and_parameters": "01_layers.ipynb",
"nn.Module.has_children": "01_layers.ipynb",
"has_children": "01_layers.ipynb",
"flatten_model": "01_layers.ipynb",
"NoneReduce": "01_layers.ipynb",
"in_channels": "01_layers.ipynb",
Expand Down
10 changes: 4 additions & 6 deletions fastai/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
'AvgPool', 'trunc_normal_', 'Embedding', 'SelfAttention', 'PooledSelfAttention2d', 'SimpleSelfAttention',
'icnr_init', 'PixelShuffle_ICNR', 'sequential', 'SequentialEx', 'MergeLayer', 'Cat', 'SimpleCNN',
'ProdLayer', 'inplace_relu', 'SEModule', 'ResBlock', 'SEBlock', 'SEResNeXtBlock', 'SeparableBlock', 'swish',
'Swish', 'MishJitAutoFn', 'mish', 'Mish', 'ParameterModule', 'children_and_parameters', 'flatten_model',
'NoneReduce', 'in_channels']
'Swish', 'MishJitAutoFn', 'mish', 'Mish', 'ParameterModule', 'children_and_parameters', 'has_children',
'flatten_model', 'NoneReduce', 'in_channels']

# Cell
from .imports import *
Expand Down Expand Up @@ -568,17 +568,15 @@ def children_and_parameters(m):
return children

# Cell
def _has_children(m:nn.Module):
def has_children(m):
try: next(m.children())
except StopIteration: return False
return True

nn.Module.has_children = property(_has_children)

# Cell
def flatten_model(m):
"Return the list of all submodules and parameters of `m`"
return sum(map(flatten_model,children_and_parameters(m)),[]) if m.has_children else [m]
return sum(map(flatten_model,children_and_parameters(m)),[]) if has_children(m) else [m]

# Cell
class NoneReduce():
Expand Down
16 changes: 6 additions & 10 deletions nbs/01_layers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@
]
},
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"output_type": "execute_result"
}
],
Expand Down Expand Up @@ -1884,12 +1882,10 @@
"outputs": [],
"source": [
"#export\n",
"def _has_children(m:nn.Module):\n",
"def has_children(m):\n",
" try: next(m.children())\n",
" except StopIteration: return False\n",
" return True\n",
"\n",
"nn.Module.has_children = property(_has_children)"
" return True"
]
},
{
Expand All @@ -1899,8 +1895,8 @@
"outputs": [],
"source": [
"class A(Module): pass\n",
"assert not A().has_children\n",
"assert TstModule().has_children"
"assert not has_children(A())\n",
"assert has_children(TstModule())"
]
},
{
Expand All @@ -1912,7 +1908,7 @@
"# export\n",
"def flatten_model(m):\n",
" \"Return the list of all submodules and parameters of `m`\"\n",
" return sum(map(flatten_model,children_and_parameters(m)),[]) if m.has_children else [m]"
" return sum(map(flatten_model,children_and_parameters(m)),[]) if has_children(m) else [m]"
]
},
{
Expand Down

0 comments on commit 12e7977

Please sign in to comment.