Fix StepRules to honour broadcastable.
dwf committed Jan 22, 2016
1 parent d9171ff commit d129aa7
Showing 3 changed files with 99 additions and 13 deletions.
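Background for the diff below: the step rules touched here allocated their state buffers with shared_floatx(parameter.get_value() * 0., ...), which builds a fresh shared variable whose broadcastable pattern defaults to all False even when the parameter has broadcastable axes, so the resulting step can fail to match the parameter's TensorType. A minimal sketch of the mismatch, assuming only NumPy and Theano; the shape and flags are illustrative and not taken from the commit:

import numpy
import theano

# A parameter with a broadcastable middle axis.
parameter = theano.shared(
    numpy.zeros((3, 1, 4), dtype=theano.config.floatX),
    broadcastable=(False, True, False))

# Old-style buffer creation: the new shared variable gets all-False flags.
old_buffer = theano.shared(parameter.get_value() * 0.)
print(old_buffer.broadcastable)   # (False, False, False)

# Passing the parameter's flags explicitly preserves them, which is what
# the new shared_floatx_zeros_matching helper in this commit does.
new_buffer = theano.shared(numpy.zeros_like(parameter.get_value()),
                           broadcastable=parameter.broadcastable)
print(new_buffer.broadcastable)   # (False, True, False)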
24 changes: 12 additions & 12 deletions blocks/algorithms/__init__.py
@@ -14,7 +14,8 @@
 from blocks.graph import ComputationGraph
 from blocks.roles import add_role, ALGORITHM_HYPERPARAMETER, ALGORITHM_BUFFER
 from blocks.theano_expressions import l2_norm
-from blocks.utils import dict_subset, pack, shared_floatx
+from blocks.utils import (dict_subset, pack, shared_floatx,
+                          shared_floatx_zeros_matching)
 
 logger = logging.getLogger(__name__)
 
@@ -420,7 +421,7 @@ def __init__(self, momentum=0.):
         add_role(self.momentum, ALGORITHM_HYPERPARAMETER)
 
     def compute_step(self, parameter, previous_step):
-        velocity = shared_floatx(parameter.get_value() * 0., "velocity")
+        velocity = shared_floatx_zeros_matching(parameter, "velocity")
         add_role(velocity, ALGORITHM_BUFFER)
         step = self.momentum * velocity + previous_step
         updates = [(velocity, step)]
@@ -487,11 +488,11 @@ def __init__(self, decay_rate=0.95, epsilon=1e-6):
         add_role(self.epsilon, ALGORITHM_HYPERPARAMETER)
 
     def compute_step(self, parameter, previous_step):
-        mean_square_step_tm1 = shared_floatx(parameter.get_value() * 0.,
-                                             "mean_square_step_tm1")
+        mean_square_step_tm1 = shared_floatx_zeros_matching(
+            parameter, "mean_square_step_tm1")
         add_role(mean_square_step_tm1, ALGORITHM_BUFFER)
-        mean_square_delta_x_tm1 = shared_floatx(parameter.get_value() * 0.,
-                                                "mean_square_delta_x_tm1")
+        mean_square_delta_x_tm1 = shared_floatx_zeros_matching(
+            parameter, "mean_square_delta_x_tm1")
         add_role(mean_square_delta_x_tm1, ALGORITHM_BUFFER)
 
         mean_square_step_t = (
@@ -550,8 +551,8 @@ def __init__(self, decay_rate=0.9, max_scaling=1e5):
         self.epsilon = 1. / max_scaling
 
     def compute_step(self, parameter, previous_step):
-        mean_square_step_tm1 = shared_floatx(parameter.get_value() * 0.,
-                                             "mean_square_step_tm1")
+        mean_square_step_tm1 = shared_floatx_zeros_matching(
+            parameter, "mean_square_step_tm1")
         add_role(mean_square_step_tm1, ALGORITHM_BUFFER)
         mean_square_step_t = (
             self.decay_rate * mean_square_step_tm1 +
@@ -749,8 +750,7 @@ def compute_step(self, parameter, previous_step):
         name = 'adagrad_sqs'
         if parameter.name:
             name += '_' + parameter.name
-        ssq = shared_floatx(parameter.get_value() * 0.,
-                            name=name)
+        ssq = shared_floatx_zeros_matching(parameter, name=name)
         add_role(ssq, ALGORITHM_BUFFER)
 
         ssq_t = (tensor.sqr(previous_step) + ssq)
@@ -796,9 +796,9 @@ def __init__(self, learning_rate=0.002,
         self.decay_factor = decay_factor
 
     def compute_step(self, parameter, previous_step):
-        mean = shared_floatx(parameter.get_value() * 0., 'mean')
+        mean = shared_floatx_zeros_matching(parameter, 'mean')
         add_role(mean, ALGORITHM_BUFFER)
-        variance = shared_floatx(parameter.get_value() * 0., 'variance')
+        variance = shared_floatx_zeros_matching(parameter, 'variance')
         add_role(variance, ALGORITHM_BUFFER)
         time = shared_floatx(0., 'time')
         add_role(time, ALGORITHM_BUFFER)
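With the changes above, each rule's state buffer now shares the parameter's broadcastable flags, so the computed step keeps the broadcastable pattern of its parameter and gradient, which is what the new tests further down assert. A quick sketch of that behaviour for Momentum, assuming Blocks at this commit plus Theano; the shape and hyperparameter values are arbitrary:

from collections import OrderedDict

from theano import tensor

from blocks.algorithms import Momentum
from blocks.utils import shared_floatx_zeros

param = shared_floatx_zeros((3, 1, 4), broadcastable=(False, True, False))
grad = tensor.grad(param.sum(), wrt=param)
steps, updates = Momentum(learning_rate=0.1, momentum=0.9).compute_steps(
    OrderedDict([(param, grad)]))
# With this fix the step's broadcastable flags match the parameter's.
print(steps[param].broadcastable)   # (False, True, False)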
29 changes: 29 additions & 0 deletions blocks/utils/__init__.py
@@ -70,6 +70,35 @@ def unpack(arg, singleton=False):
         return arg
 
 
+def shared_floatx_zeros_matching(shared_variable, name=None, **kwargs):
+    """Create another shared variable with matching shape and broadcast.
+
+    Parameters
+    ----------
+    shared_variable : :class:`tensor.TensorSharedVariable`
+        A Theano shared variable with the desired shape and broadcastable
+        flags.
+    name : :obj:`str`, optional
+        The name for the shared variable. Defaults to `None`.
+
+    Returns
+    -------
+    :class:`tensor.TensorSharedVariable`
+        A new shared variable, initialized to all zeros, with the same
+        shape and broadcastable flags as `shared_variable`.
+
+    \*\*kwargs
+        Keyword arguments to pass to the :func:`shared_floatx_zeros` function.
+
+    """
+    if not is_shared_variable(shared_variable):
+        raise ValueError('argument must be a shared variable')
+    return shared_floatx_zeros(shared_variable.get_value().shape,
+                               name=name,
+                               broadcastable=shared_variable.broadcastable,
+                               **kwargs)
+
+
 def shared_floatx_zeros(shape, **kwargs):
     r"""Creates a shared variable array filled with zeros.
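For reference, a small usage sketch of the new helper, assuming Blocks with this commit applied; the shape, flags, and names are illustrative only:

from blocks.utils import shared_floatx_zeros, shared_floatx_zeros_matching

# A shared variable with a broadcastable leading axis, as in the new tests.
weights = shared_floatx_zeros((1, 5), name='weights',
                              broadcastable=(True, False))

# Zero-filled buffer with the same shape and broadcastable flags as `weights`.
buf = shared_floatx_zeros_matching(weights, name='weights_buffer')
print(buf.get_value().shape)   # (1, 5)
print(buf.broadcastable)       # (True, False)

# Anything that is not a shared variable is rejected with a ValueError.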
59 changes: 58 additions & 1 deletion tests/algorithms/test_algorithms.py
@@ -10,7 +10,27 @@
                                CompositeRule, Scale, StepRule, BasicMomentum,
                                Momentum, AdaDelta, BasicRMSProp, RMSProp, Adam,
                                AdaGrad, RemoveNotFinite, Restrict)
-from blocks.utils import shared_floatx
+from blocks.utils import shared_floatx, shared_floatx_zeros
 
 
+def verify_broadcastable_handling(step_rule):
+    def check(param):
+        grad = tensor.grad(param.sum(), wrt=param)
+        step, _ = step_rule.compute_steps(OrderedDict([(param, grad)]))
+        assert step[param].broadcastable == grad.broadcastable
+
+    check(shared_floatx_zeros((5, 6, 1, 5),
+                              broadcastable=(False, False, True, False)))
+    check(shared_floatx_zeros((2, 1, 3),
+                              broadcastable=(False, True, False)))
+    check(shared_floatx_zeros((3, 4, 1),
+                              broadcastable=(False, False, True)))
+    check(shared_floatx_zeros((1, 9, 6),
+                              broadcastable=(True, False, False)))
+    check(shared_floatx_zeros((1, 1, 1),
+                              broadcastable=(True, True, True)))
+    check(shared_floatx_zeros((1, 5, 1),
+                              broadcastable=(True, False, True)))
+
+
 def test_gradient_descent():
@@ -69,6 +89,10 @@ def test_basic_momentum():
     assert_allclose(f()[0], [10.5, 14.])
 
 
+def test_basic_momentum_broadcastable():
+    verify_broadcastable_handling(BasicMomentum(0.5))
+
+
 def test_momentum():
     a = shared_floatx([3, 4])
     cost = (a ** 2).sum()
@@ -80,6 +104,10 @@ def test_momentum():
     assert_allclose(f()[0], [1.05, 1.4])
 
 
+def test_momentum_broadcastable():
+    verify_broadcastable_handling(Momentum(0.5))
+
+
 def test_adadelta():
     a = shared_floatx([3, 4])
     cost = (a ** 2).sum()
@@ -110,6 +138,10 @@ def test_basicrmsprop():
     assert_allclose(f()[0], [0.6172134, 0.64699664])
 
 
+def test_basicrmsprop_broadcastable():
+    verify_broadcastable_handling(BasicRMSProp(0.5, 1e5))
+
+
 def test_basicrmsprop_max_scaling():
     a = shared_floatx([1e-6, 1e-6])
     cost = (a ** 2).sum()
@@ -143,6 +175,10 @@ def test_rmsprop():
     assert_allclose(f()[0], [0.06172134, 0.064699664])
 
 
+def test_rmsprop_broadcastable():
+    verify_broadcastable_handling(RMSProp(0.1, 0.5, 1e5))
+
+
 def test_step_clipping():
     rule1 = StepClipping(4)
     rule2 = StepClipping(5)
@@ -156,6 +192,10 @@ def test_step_clipping():
     assert_allclose(clipped2[1].eval(), 4.0)
 
 
+def test_step_clipping_broadcastable():
+    verify_broadcastable_handling(StepClipping(0.4))
+
+
 def test_variable_clipping():
     # Test simple variable clipping with no axis.
     rule1 = VariableClipping(5)
@@ -208,6 +248,10 @@ def test_variable_clipping():
     assert_raises(ValueError, VariableClipping, 50, axis=(0, 0))
 
 
+def test_variable_clipping_broadcastable():
+    verify_broadcastable_handling(VariableClipping(1))
+
+
 def test_composite_rule():
     rule = CompositeRule([StepClipping(4), Scale(0.1)])
     gradients = {0: shared_floatx(3.0), 1: shared_floatx(4.0)}
@@ -242,6 +286,10 @@ def test_adam():
     assert_allclose(f()[0], [0.00178724, 0.0018223], rtol=rtol)
 
 
+def test_adam_broadcastable():
+    verify_broadcastable_handling(Adam())
+
+
 def test_adagrad():
     a = shared_floatx([3, 4])
     cost = (a ** 2).sum()
@@ -257,6 +305,10 @@ def test_adagrad():
     assert_allclose(f()[0], [0.00053452, 0.0005747], rtol=rtol)
 
 
+def test_adagrad_broadcastable():
+    verify_broadcastable_handling(AdaGrad())
+
+
 def test_remove_not_finite():
     rule1 = RemoveNotFinite(0.1)
     rule2 = RemoveNotFinite()
@@ -272,6 +324,11 @@ def test_remove_not_finite():
     assert_allclose(rval2[2].eval(), 0.0)
 
 
+def test_remove_not_finite_broadcastable():
+    verify_broadcastable_handling(RemoveNotFinite())
+    verify_broadcastable_handling(RemoveNotFinite(0.1))
+
+
 class DummyUpdatesStepRule(StepRule):
     def compute_step(self, parameter, previous_step):
         return previous_step + 2, [(parameter * 10, parameter * 100)]
