Skip to content

Commit

Permalink
Merge pull request #2 from Lissanro/bottleneck
Browse files Browse the repository at this point in the history
Search for both basic and bottleneck blocks (to fix "no known network structure detected" warning with ResNet-50 and other similar models)
  • Loading branch information
nestordemeure committed Jun 8, 2020
2 parents bef180f + 85fa869 commit 2c807aa
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ See the [Annealing](http://dev.fast.ai/callback.schedule#Annealing) section of f
`ManifoldMixup` tries to establish a sensible list of modules on which to apply mixup:
- it uses a user provided `module_list` if possible
- otherwise it uses only the modules wrapped with `ManifoldMixupModule`
- if none are found, it defaults to modules with `Block` in their name (targetting mostly resblocks)
- if none are found, it defaults to modules with `Block` or `Bottleneck` in their name (targetting mostly resblocks)
- finaly, if needed, it defaults to all modules that are not included in the `non_mixable_module_types` list

The `non_mixable_module_types` list contains mostly recurrent layers but you can add elements to it in order to define module classes that should not be used for mixup (*do not hesitate to create an issue or start a PR to add common modules to the default list*).
Expand Down
5 changes: 3 additions & 2 deletions manifold_mixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def _is_mixable(m):
return not any(isinstance(m, non_mixable_class) for non_mixable_class in non_mixable_module_types)

def _is_block_module(m):
"Checks wether a module is a Block (typically a kind of resBlock)"
return "block" in str(type(m)).lower()
"Checks whether a module is a Block or Bottleneck (typically a kind of resBlock)"
m = str(type(m)).lower()
return "block" in m or "bottleneck" in m

def _get_mixup_module_list(model):
"returns all the modules that can be used for mixup"
Expand Down

0 comments on commit 2c807aa

Please sign in to comment.