Skip to content

Commit

Permalink
fix block.export (apache#17970)
Browse files Browse the repository at this point in the history
* fix block.export

```net.hybridize``` may optimize out some ops. These ops are alive in nn.Block(also nn.HybridBlock), but its names are not contained in symbol's ```arg_names``` list. So ignore these ops except that their name are end with 'running_mean' or 'running_var'.

* Update block.py

let user can save their extra param.

* add allow_extra

add allow_extra to let user decide whether to save extra parameters or not.

* Update block.py

add moving_mean and moving_var when export model with SymbolBlock

* Update python/mxnet/gluon/block.py

typo

Co-authored-by: Sheng Zha <szha@users.noreply.github.com>

* Update block.py

* Update block.py

* Update python/mxnet/gluon/block.py

Co-authored-by: Leonard Lausen <leonard@lausen.nl>

Co-authored-by: Sheng Zha <szha@users.noreply.github.com>
Co-authored-by: Leonard Lausen <leonard@lausen.nl>
  • Loading branch information
3 people authored and James Mracek committed Sep 2, 2020
1 parent 0e86a91 commit 9a81a40
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions python/mxnet/gluon/block.py
Expand Up @@ -1195,12 +1195,16 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
for name, param in self.collect_params().items():
if name in arg_names:
arg_dict['arg:%s'%name] = param._reduce()
else:
assert name in aux_names
arg_dict['aux:%s'%name] = param._reduce()
for is_arg, name, param in self._cached_op_args:
if not is_arg:
if name in arg_names:
arg_dict['arg:{}'.format(name)] = param._reduce()
else:
if name not in aux_names:
warnings.warn('Parameter "{name}" is not found in the graph. '
.format(name=name), stacklevel=3)
else:
arg_dict['aux:%s'%name] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn('%s-%04d.params'%(path, epoch), arg_dict)

Expand Down

0 comments on commit 9a81a40

Please sign in to comment.