Skip to content

Commit

Permalink
Merge pull request #1101 from dwf/squared_l2_norm
Browse files Browse the repository at this point in the history
Bug fix + enhancement to l2_norm.
  • Loading branch information
dmitriy-serdyuk committed Jun 1, 2016
2 parents 449a80d + 1bc1fc1 commit 387d159
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 5 additions & 3 deletions blocks/theano_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from theano import tensor


def l2_norm(tensors):
def l2_norm(tensors, squared=False):
"""Computes the total L2 norm of a set of tensors.
Converts all operands to :class:`~tensor.TensorVariable`
Expand All @@ -12,11 +12,13 @@ def l2_norm(tensors):
----------
tensors : iterable of :class:`~tensor.TensorVariable` (or compatible)
The tensors.
squared : bool, optional
If `True`, return the squared L2 norm. Default: `False`.
"""
summed = [tensor.sqr(tensor.as_tensor_variable(t)).sum() for t in tensors]
joined = tensor.stack(*summed)
return tensor.sqrt(joined.sum())
joined = tensor.stack(summed, axis=0)
return joined.sum() if squared else tensor.sqrt(joined.sum())


def hessian_times_vector(gradient, parameter, vector, r_op=False):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_theano_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def test_l2_norm():
assert_allclose(l2_norm([3, [1, 2]]).eval(), 14.0 ** 0.5)
assert_allclose(
l2_norm([3, [1, 2], [[1, 2], [3, 4]]]).eval(), 44.0 ** 0.5)
assert_allclose(
l2_norm([3, [1, 2], [[1, 2], [3, 4]]], squared=True).eval(), 44.0)


def test_hessian_times_vector():
Expand Down

0 comments on commit 387d159

Please sign in to comment.