Skip to content

Commit

Permalink
Merge pull request #1146 from dwf/step_clipping_corner_case
Browse files Browse the repository at this point in the history
Fix StepClipping return when threshold=None.
  • Loading branch information
dwf committed Sep 12, 2016
2 parents e1fedb0 + e757ced commit b93253d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
24 changes: 13 additions & 11 deletions blocks/algorithms/__init__.py
Expand Up @@ -696,19 +696,21 @@ class StepClipping(StepRule):
"""
def __init__(self, threshold=None):
if threshold:
self.threshold = shared_floatx(threshold, "threshold")
add_role(self.threshold, ALGORITHM_HYPERPARAMETER)
if threshold is not None:
threshold = shared_floatx(threshold, "threshold")
add_role(threshold, ALGORITHM_HYPERPARAMETER)
self.threshold = threshold

def compute_steps(self, previous_steps):
if not hasattr(self, 'threshold'):
return previous_steps
norm = l2_norm(previous_steps.values())
multiplier = tensor.switch(norm < self.threshold,
1, self.threshold / norm)
steps = OrderedDict(
(parameter, step * multiplier)
for parameter, step in previous_steps.items())
if self.threshold is None:
steps = previous_steps
else:
norm = l2_norm(previous_steps.values())
multiplier = tensor.switch(norm < self.threshold,
1, self.threshold / norm)
steps = OrderedDict(
(parameter, step * multiplier)
for parameter, step in previous_steps.items())
return steps, []


Expand Down
10 changes: 10 additions & 0 deletions tests/algorithms/test_algorithms.py
Expand Up @@ -311,6 +311,16 @@ def test_step_clipping():
assert_allclose(clipped2[1].eval(), 4.0)


def test_step_clipping_no_threshold_regression():
"""Test regression for #1145, incorrect output when threshold=None."""
rule1 = StepClipping()
assert rule1.threshold is None
gradients = {0: shared_floatx(3.0), 1: shared_floatx(4.0)}
clipped1, updates = rule1.compute_steps(gradients)
assert len(updates) == 0
assert clipped1 == gradients


def test_step_clipping_broadcastable():
verify_broadcastable_handling(StepClipping(0.4))

Expand Down

0 comments on commit b93253d

Please sign in to comment.