Skip to content

Commit

Permalink
Merge pull request #1075 from mila-udem/override_sharedvariablemodifier
Browse files Browse the repository at this point in the history
Allow override pattern for SharedVariableModifier.
  • Loading branch information
vdumoulin committed May 4, 2016
2 parents 3559b34 + cfc9c0c commit 41dd4af
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions blocks/extensions/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,38 @@ class SharedVariableModifier(SimpleExtension):
In the second case, it is a function which takes number of
iterations done (``int``) and old value of the shared variable
(with the same dtype as `parameter`).
num_args : int, optional
The number of arguments to pass to the function. If unspecified,
it will be inferred. This is useful if you are using function-like
objects for which the arity of the function cannot be inferred.
Notes
-----
This class includes a method ``function`` that calls the function
passed in the constructor and a ``num_args`` property which computes
the number of arguments to use by inspecting the function object.
Subclasses may override a method called ``function`` and/or
the ``num_args`` property and instead pass ``None`` to the superclass
constructor. This can be used to bypass certain serialization issues
on Legacy Python regarding the unpicklability of instance
method objects.
"""
def __init__(self, parameter, function, **kwargs):
def __init__(self, parameter, function, num_args=None, **kwargs):
kwargs.setdefault("after_batch", True)
super(SharedVariableModifier, self).__init__(**kwargs)
self.parameter = parameter
self.function = function
self.num_args = len(inspect.getargspec(function).args)
self._function = function
self._num_args = num_args

@property
def num_args(self):
if self._num_args is None:
self._num_args = len(inspect.getargspec(self._function).args)
return self._num_args

def function(self, *args):
return self._function(*args)

def do(self, which_callback, *args):
iterations_done = self.main_loop.log.status['iterations_done']
Expand Down

0 comments on commit 41dd4af

Please sign in to comment.