Skip to content

Commit

Permalink
[Refactoring] Removed code duplication in the theano backend. (#11131)
Browse files Browse the repository at this point in the history
* Removed code duplication in the theano backend.

* Changed the name of the function to _set_keras_shape_for_reduction.
  • Loading branch information
gabrieldemarmiesse authored and fchollet committed Sep 13, 2018
1 parent 06c3a80 commit ba7ab2f
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,31 +589,19 @@ def any(x, axis=None, keepdims=False):
"""Bitwise reduction (logical OR).
"""
y = T.any(x, axis=axis, keepdims=keepdims)
if hasattr(x, '_keras_shape'):
if axis is None:
y._keras_shape = (1,) * len(x._keras_shape) if keepdims else (1,)
else:
if isinstance(axis, int):
axis_list = [axis]
else:
axis_list = list(set(int(a) for a in axis))
keras_shape_list = list(x._keras_shape)
if keepdims:
for a in axis_list:
keras_shape_list[a] = 1
else:
for a in axis_list[::-1]:
keras_shape_list.pop(a)
if not keras_shape_list:
keras_shape_list = (1,)
y._keras_shape = tuple(keras_shape_list)
y = _set_keras_shape_for_reduction(x, y, axis, keepdims)
return y


def all(x, axis=None, keepdims=False):
"""Bitwise reduction (logical AND).
"""
y = T.all(x, axis=axis, keepdims=keepdims)
y = _set_keras_shape_for_reduction(x, y, axis, keepdims)
return y


def _set_keras_shape_for_reduction(x, y, axis, keepdims):
if hasattr(x, '_keras_shape'):
if axis is None:
y._keras_shape = (1,) * len(x._keras_shape) if keepdims else (1,)
Expand Down

0 comments on commit ba7ab2f

Please sign in to comment.