Skip to content

Commit

Permalink
Merge pull request #121 from bartvm/pickle_bug
Browse files Browse the repository at this point in the history
Get rid of class factory for activations
  • Loading branch information
bartvm committed Jan 20, 2015
2 parents eae3f54 + c68ce6c commit 4ff20da
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 62 deletions.
118 changes: 60 additions & 58 deletions blocks/bricks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,65 +1026,67 @@ def apply(self, input_):
return output


class _PicklableActivation(object):
"""A base class for dynamically generated classes that can be pickled."""
def __reduce__(self):
activation = self.__class__._activation
if hasattr(activation, '__func__'):
activation = activation.__func__
return (_Initializor(),
(self.__class__.__name__, activation),
self.__dict__)


class _Initializor(object):
"""A callable object which returns a parametrized class."""
def __call__(self, name, activation):
object_ = _Initializor()
object_.__class__ = _activation_factory(name, activation)
return object_


def _activation_factory(name, activation):
"""Class factory for Bricks which perform simple Theano calls."""
class ActivationDocumentation(type):
def __new__(cls, name, bases, classdict):
classdict['__doc__'] = classdict['__doc__'].format(name.lower())
class ActivationDocumentation(type):
def __new__(cls, name, bases, classdict):
classdict['__doc__'] = \
"""Elementwise application of {0} function.""".format(name.lower())
if 'apply' in classdict:
classdict['apply'].__doc__ = \
classdict['apply'].__doc__.format(name.lower())
return type.__new__(cls, name, bases, classdict)

@add_metaclass(ActivationDocumentation)
class Activation(Brick, _PicklableActivation):
"""Element-wise application of {0} function."""
_activation = activation

@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
"""Apply the {0} function element-wise.
Parameters
----------
input_ : Theano variable
Theano variable to apply {0} to, element-wise.
Returns
-------
output : Theano variable
The input with the activation function applied.
"""
output = activation(input_)
return output
Activation.__name__ = name
return Activation

Identity = _activation_factory('Identity', lambda x: x)
Tanh = _activation_factory('Tanh', tensor.tanh)
Sigmoid = _activation_factory('Sigmoid', tensor.nnet.sigmoid)
Softmax = _activation_factory('Softmax', tensor.nnet.softmax)
Rectifier = _activation_factory('Rectifier',
lambda x: tensor.switch(x > 0, x, 0))
"""Apply the {0} function elementwise.
Parameters
----------
input_ : Theano variable
Theano variable to apply {0} to, elementwise.
Returns
-------
output : Theano variable
The input with the activation function applied.
""".format(name.lower())
return type.__new__(cls, name, bases, classdict)


@add_metaclass(ActivationDocumentation)
class Activation(Brick):
"""A base class for simple, elementwise activation functions.
This base class ensures that activation functions are automatically
documented using the :class:`ActivationDocumentation` metaclass.
"""
pass


class Identity(Activation):
@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
return input_


class Tanh(Activation):
@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
return tensor.tanh(input_)


class Sigmoid(Activation):
@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
return tensor.nnet.sigmoid(input_)


class Rectifier(Activation):
@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
return tensor.switch(input_ > 0, input_, 0)


class Softmax(Activation):
@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
return tensor.nnet.softmax(input_)


class Sequence(Brick):
Expand Down
5 changes: 1 addition & 4 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ def test_mnist():

def test_pylearn2():
filename = 'unittest_markov_chain'
try:
pylearn2_test('train', filename, 0, 3, False)
except (dill.PicklingError, OSError):
pass
pylearn2_test('train', filename, 0, 3, False)
os.remove(filename)

test_pylearn2.setup = setup

0 comments on commit 4ff20da

Please sign in to comment.