This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

- new callback to maintain parameter EMAs (not tested yet)
- minor update for callback instance registry
nasimrahaman committed Sep 9, 2017
1 parent e253fcd commit ed610d8
Showing 2 changed files with 34 additions and 1 deletion.
inferno/trainers/callbacks/base.py (2 changes: 1 addition, 1 deletion)
@@ -130,7 +130,7 @@ def __init__(self):

     @classmethod
     def register_instance(cls, instance):
-        if hasattr(cls, '_instance_registry'):
+        if hasattr(cls, '_instance_registry') and instance not in cls._instance_registry:
             cls._instance_registry.append(instance)
         else:
             cls._instance_registry = [instance]
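
For context, the intent of the guard added above is to keep the same callback instance from being appended to the class-level registry more than once. A minimal, self-contained sketch of that pattern, with a hypothetical DemoRegistry stand-in rather than the actual inferno Callback base class and the control flow restructured slightly for clarity:

class DemoRegistry(object):
    """Hypothetical stand-in for a class-level callback instance registry."""

    @classmethod
    def register_instance(cls, instance):
        # Create the registry lazily on first use.
        if not hasattr(cls, '_instance_registry'):
            cls._instance_registry = []
        # Append only if this exact instance is not registered yet, so a
        # repeated registration of the same object is a no-op.
        if instance not in cls._instance_registry:
            cls._instance_registry.append(instance)


a, b = DemoRegistry(), DemoRegistry()
DemoRegistry.register_instance(a)
DemoRegistry.register_instance(b)
DemoRegistry.register_instance(a)  # no-op: already registered
assert DemoRegistry._instance_registry == [a, b]
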
inferno/trainers/callbacks/essentials.py (33 additions, 0 deletions)
@@ -215,3 +215,36 @@ def end_of_validation_run(self, **_):
                 .format(self._ema_validation_score,
                         self._best_ema_validation_score))
         # Done
+
+
+class ParameterEMA(Callback):
+    """Maintain a moving average of network parameters."""
+    def __init__(self, momentum):
+        """
+        Parameters
+        ----------
+        momentum : float
+            Momentum for the moving average. The following holds:
+            `new_moving_average = momentum * old_moving_average + (1 - momentum) * value`
+        """
+        super(ParameterEMA, self).__init__()
+        # Privates
+        self._parameters = None
+        # Publics
+        self.momentum = momentum
+
+    def maintain(self):
+        if self._parameters is None:
+            self._parameters = [p.data.new().zero_() for p in self.trainer.model.parameters()]
+        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):
+            p_ema.mul_(self.momentum).add_(p_model.data.mul(1. - self.momentum))
+
+    def apply(self):
+        assert_(self._parameters is not None,
+                "Can't apply parameter EMA's: not available.",
+                ValueError)
+        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):
+            p_model.data.copy_(p_ema)
+
+    def end_of_training_iteration(self, **_):
+        self.maintain()
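
The callback hooks into end_of_training_iteration, so maintain() updates the running averages once per training step, while apply() overwrites the live model parameters with their averaged values, for example before evaluation or checkpointing. The update follows the rule quoted in the docstring: new_moving_average = momentum * old_moving_average + (1 - momentum) * value. Below is a standalone sketch of the same arithmetic in plain PyTorch, outside the trainer; the nn.Linear model and momentum value are placeholders, and initializing the EMA buffers from a copy of the current parameters is an assumption of this sketch rather than what the callback above does.

import torch.nn as nn

# Standalone illustration only: the model and momentum are placeholders,
# not values taken from the commit above.
model = nn.Linear(4, 2)
momentum = 0.99

# Assumption of this sketch: start the EMA buffers from a copy of the
# current parameters (the callback above starts from freshly created tensors).
ema_parameters = [p.data.clone() for p in model.parameters()]

def maintain_ema():
    # new_moving_average = momentum * old_moving_average + (1 - momentum) * value
    for p_model, p_ema in zip(model.parameters(), ema_parameters):
        p_ema.mul_(momentum).add_(p_model.data.mul(1. - momentum))

def apply_ema():
    # Overwrite the live parameters with their moving averages,
    # e.g. before evaluation or checkpointing.
    for p_model, p_ema in zip(model.parameters(), ema_parameters):
        p_model.data.copy_(p_ema)

# Typical use: update the averages once per training iteration, then apply
# them when the averaged weights are needed.
for _ in range(10):
    # ... forward pass, backward pass and optimizer step would go here ...
    maintain_ema()
apply_ema()

Hooking the actual callback into a trainer would presumably look like trainer.register_callback(ParameterEMA(momentum=0.99)), but that call is an assumption about the inferno trainer API rather than something shown in this commit.
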
