
Easy support for per-module train/eval state #37

Closed
AlexeyKurakin opened this issue Sep 9, 2020 · 5 comments
@AlexeyKurakin (Member)

This is a tracking issue to improve support for per-module train/eval state in Objax.
This issue originates from the discussion of PyTorch-style vs Objax-style propagation of training/eval mode in #29.

The PyTorch style provides an easy way to set per-module train/eval mode, as in the following example:

model = Resnet50(nclasses=1000)
…
# Example: put most of the network into training mode,
# except for a few modules which stay in eval mode
model.train()
model.block_1.bn_1.eval()
model.block_2.bn_2.eval()

There is no clean way to achieve the same thing in Objax right now. One possibility is to use functools.partial:

import functools

model = Resnet50(nclasses=1000)
…
# Example: force certain batch norms into eval mode
model.block_1.bn_1 = functools.partial(model.block_1.bn_1, training=False)
model.block_2.bn_2 = functools.partial(model.block_2.bn_2, training=False)

# The following line calls the model in training mode,
# except for block_1.bn_1 and block_2.bn_2
y = model(x, training=True)

However, there are some problems with functools.partial:

  • it converts the module into a plain function (thus vars() is not propagated);
  • if bn_eval = functools.partial(bn, training=False) and a caller passes the training argument to bn_eval positionally, it causes a run-time error (see the sketch after this list);
  • there is no easy way to undo functools.partial once it has been applied to a module.
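
For example, here is a minimal sketch of the second problem, using a plain function as a stand-in for a module's __call__ (the names are illustrative):

import functools

def bn(x, training=True):
    # Stand-in for a batch norm module's __call__.
    return x

bn_eval = functools.partial(bn, training=False)

bn_eval(1.0)        # OK: training is forced to False
bn_eval(1.0, True)  # TypeError: bn() got multiple values for argument 'training'

Note that passing training as a keyword (bn_eval(1.0, training=True)) does not raise at all: the call-time keyword silently overrides the partial's value.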

Thus we need a better solution for per-module train/eval state.

@AlexeyKurakin (Member, Author) commented Sep 23, 2020

Here is one idea of how this could be implemented:

######### in module.py

class ArgsOverride(Module):

    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        self.kwargs = kwargs

    def vars(self, scope=''):
        # Presence of the ArgsOverride module won't affect variable names.
        return self.base_module.vars(scope)

    def __call__(self, *args, **kwargs):
        # In practice this should only override kwargs which are present
        # in the signature of base_module.
        kwargs.update(self.kwargs)
        return self.base_module(*args, **kwargs)

######### in utils.py

def reset_args_override(module):
    # Removes ArgsOverride from module and all submodules.
    ...

######### in user code

model = Resnet50(nclasses=1000)
…
# Example: put most of the network into training mode,
# except for a few modules forced into eval mode
model.block_1.bn_1 = objax.ArgsOverride(model.block_1.bn_1, training=False)
model.block_2.bn_2 = objax.ArgsOverride(model.block_2.bn_2, training=False)

# reset all args override
objax.utils.reset_args_override(model)

# set override on different sub-module
model.block_3.bn_3 = objax.ArgsOverride(model.block_3.bn_3, training=False)
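
The reset_args_override helper above is only stubbed out. Here is a minimal sketch of how it might walk the module hierarchy, assuming submodules are stored as plain attributes (the traversal is illustrative and ignores containers such as module lists):

def reset_args_override(module):
    # Removes ArgsOverride wrappers from module and all submodules.
    for name, value in list(vars(module).items()):
        if isinstance(value, ArgsOverride):
            setattr(module, name, value.base_module)
            value = value.base_module
        if isinstance(value, Module):
            reset_args_override(value)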

@rwightman do you have any feedback about this one?

@rwightman commented Sep 23, 2020

@AlexeyKurakin that could work... as implemented, ArgsOverride is quite generic and could be used for any arg. Is there any other arg passed through the __call__ chain that you think one would want to override? If not, something like ForceNotTraining(Module) / ForceTraining(Module) without the need to specify kwargs would be a bit clearer.

In the absence of additional functionality here, I was likely going to go a subclassing / alternate module implementation route ... basically create FrozenBatchNorm, EvalBatchNorm, EvalDropout style classes plus helpers that walk the module hierarchy within a subset of the model and switch class types (and copy state), as sketched below. But then that'd run into the checkpoint compatibility issues discussed.
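
A minimal sketch of that subclassing route (the class and helper names are hypothetical, and switching __class__ in place assumes the replacement class adds no new state):

import objax

class EvalBatchNorm2D(objax.nn.BatchNorm2D):
    # A batch norm that always runs in eval mode, ignoring the caller's flag.
    def __call__(self, x, training):
        return super().__call__(x, training=False)

def freeze_batch_norms(module):
    # Walk the module hierarchy and switch batch norm class types in place;
    # reassigning __class__ keeps the existing variables, so no state copy is needed.
    for value in list(vars(module).values()):
        if isinstance(value, objax.nn.BatchNorm2D):
            value.__class__ = EvalBatchNorm2D
        elif isinstance(value, objax.Module):
            freeze_batch_norms(value)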

@AlexeyKurakin (Member, Author)

Right now I can't really think of other arguments besides training, though I guess overriding other args might be convenient for some non-standard modules.

Also, ArgsOverride enables syntax like the following:

# without ArgsOverride
predict = objax.Jit(lambda x: model(x, training=False), model.vars())

# with ArgsOverride: vars are propagated, so Jit can collect them itself
predict = objax.Jit(objax.ArgsOverride(model, training=False))

@david-berthelot added the feature request label Sep 24, 2020
@AlexeyKurakin (Member, Author)

@rwightman Eventually we decided to use the name ForceArgs for this feature, and as I mentioned above it's somewhat more generic than simply forcing the training flag. The change is now merged into the repository and available for use.
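
A usage sketch, assuming the merged ForceArgs mirrors the ArgsOverride proposal above (see the Objax documentation for the exact API):

import objax

net = objax.nn.Sequential([
    objax.nn.Conv2D(3, 8, k=3),
    objax.nn.BatchNorm2D(8),
])

# Force the batch norm into eval mode; the rest of the network
# still follows the training flag passed at call time.
net[1] = objax.ForceArgs(net[1], training=False)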

@rwightman
@AlexeyKurakin thanks for the heads up, looks good
