Skip to content

Commit

Permalink
tensorflow 0.12 fixes (#4815)
Browse files Browse the repository at this point in the history
* initial tensorflow 0.12 fixes

see #4805

* fixed indents for pep8

* added tests for clipnorm and clipvalues

* updated travis to tf 0.12.1

* batch_matmul removed

even though the tests don’t fail on travis… they fail locally…

* make changes work with TF 0.11

* move statement outside of if
  • Loading branch information
kashif authored and fchollet committed Jan 11, 2017
1 parent 89f0527 commit 875bc59
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 18 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Expand Up @@ -49,9 +49,9 @@ install:

# install TensorFlow
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.11.0-cp27-none-linux_x86_64.whl;
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl;
elif [[ "$TRAVIS_PYTHON_VERSION" == "3.4" ]]; then
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.11.0-cp34-cp34m-linux_x86_64.whl;
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl;
fi
# command to run tests
script:
Expand Down
38 changes: 23 additions & 15 deletions keras/backend/tensorflow_backend.py
Expand Up @@ -867,10 +867,14 @@ def batch_dot(x, y, axes=None):
else:
adj_x = None
adj_y = None
try:
out = tf.batch_matmul(x, y, adj_a=adj_x, adj_b=adj_y)
except TypeError:
out = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
# TODO: remove later.
if hasattr(tf, 'batch_matmul'):
try:
out = tf.batch_matmul(x, y, adj_a=adj_x, adj_b=adj_y)
except TypeError:
out = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
else:
out = tf.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if ndim(out) == 1:
out = expand_dims(out, 1)
return out
Expand Down Expand Up @@ -1697,6 +1701,10 @@ def rnn(step_function, inputs, initial_states,
if constants is None:
constants = []

# TODO: remove later.
if hasattr(tf, 'select'):
tf.where = tf.select

if unroll:
if not inputs.get_shape()[0]:
raise ValueError('Unrolling requires a '
Expand All @@ -1717,7 +1725,7 @@ def rnn(step_function, inputs, initial_states,
for input, mask_t in zip(input_list, mask_list):
output, new_states = step_function(input, states + constants)

# tf.select needs its condition tensor
# tf.where needs its condition tensor
# to be the same shape as its two
# result tensors, but in our case
# the condition (mask) tensor is
Expand All @@ -1735,16 +1743,16 @@ def rnn(step_function, inputs, initial_states,
else:
prev_output = successive_outputs[-1]

output = tf.select(tiled_mask_t, output, prev_output)
output = tf.where(tiled_mask_t, output, prev_output)

return_states = []
for state, new_state in zip(states, new_states):
# (see earlier comment for tile explanation)
tiled_mask_t = tf.tile(mask_t,
stack([1, tf.shape(new_state)[1]]))
return_states.append(tf.select(tiled_mask_t,
new_state,
state))
return_states.append(tf.where(tiled_mask_t,
new_state,
state))
states = return_states
successive_outputs.append(output)
successive_states.append(states)
Expand Down Expand Up @@ -1805,8 +1813,8 @@ def _step(time, output_ta_t, *states):
new_state.set_shape(state.get_shape())
tiled_mask_t = tf.tile(mask_t,
stack([1, tf.shape(output)[1]]))
output = tf.select(tiled_mask_t, output, states[0])
new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
output = tf.where(tiled_mask_t, output, states[0])
new_states = [tf.where(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)
else:
Expand Down Expand Up @@ -1939,7 +1947,7 @@ def elu(x, alpha=1.):
if alpha == 1:
return res
else:
return tf.select(x > 0, res, alpha * res)
return tf.where(x > 0, res, alpha * res)


def softmax(x):
Expand Down Expand Up @@ -2426,9 +2434,9 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
dtype = floatx()
if seed is None:
seed = np.random.randint(10e6)
return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
tf.ones(shape, dtype=dtype),
tf.zeros(shape, dtype=dtype))
return tf.where(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
tf.ones(shape, dtype=dtype),
tf.zeros(shape, dtype=dtype))


# CTC
Expand Down
4 changes: 3 additions & 1 deletion keras/optimizers.py
@@ -1,7 +1,9 @@
from __future__ import absolute_import

from six.moves import zip

from . import backend as K
from .utils.generic_utils import get_from_module
from six.moves import zip


def clip_norm(g, c, n):
Expand Down
10 changes: 10 additions & 0 deletions tests/keras/test_optimizers.py
Expand Up @@ -74,5 +74,15 @@ def test_nadam():
_test_optimizer(Nadam())


def test_clipnorm():
sgd = SGD(lr=0.01, momentum=0.9, clipnorm=0.5)
_test_optimizer(sgd)


def test_clipvalue():
sgd = SGD(lr=0.01, momentum=0.9, clipvalue=0.5)
_test_optimizer(sgd)


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 875bc59

Please sign in to comment.