Update cntk backend with CNTK 2.2 release (#7907)
* integrate with cntk native batch axis convert api

* support mask in recurrent layer

* update interface name

* add support for the case where a static axis is in front of the batch axis

* add reverse support

* using native padding api

* update ci test with cntk 2.1

* fix style issue

* use one hot workaround

* api change

* fix style issue

* add comments

* fix bugs in merge

* add output rank to one hot approach

* use cntk padding

* use native batch reshape

* fix the reversed rnn bug

* update padding interface

* update reshape to support free dimension

* workaround free dimension in squeeze

* enable tests with non-specified shape

* also check free dimension

* remove useless padding code

* fix a free dimension bug

* add backward compatible check

* fix pep8 issue

* update travis with cntk 2.2
souptc authored and fchollet committed Sep 16, 2017
1 parent b24f444 commit 5938501
Showing 5 changed files with 93 additions and 112 deletions.
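The common thread in these changes is gating new CNTK 2.2 behavior on a runtime version check, so the backend keeps working on CNTK 2.0/2.1. A minimal sketch of that pattern, assuming CNTK is installed; `_cntk_version` here is a simplified stand-in for the `_get_cntk_version` helper added in cntk_backend.py below:

import cntk as C

def _cntk_version():
    # Simplified stand-in for _get_cntk_version below: strip the '+'
    # suffix of dev builds ('2.2+') and parse the rest as a float.
    try:
        return float(C.__version__.rstrip('+'))
    except ValueError:
        return 2.0  # assume 2.0 GA when the version string is unrecognized

# CNTK 2.2 introduces FreeDimension; older releases only offer InferredDimension.
dynamic_dimension = C.FreeDimension if _cntk_version() >= 2.2 else C.InferredDimension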
4 changes: 2 additions & 2 deletions .travis.yml
@@ -56,9 +56,9 @@ install:

  # install cntk
  - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
-      pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.1-cp27-cp27mu-linux_x86_64.whl;
+      pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.2-cp27-cp27mu-linux_x86_64.whl;
    elif [[ "$TRAVIS_PYTHON_VERSION" == "3.5" ]]; then
-      pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.1-cp35-cp35m-linux_x86_64.whl;
+      pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.2-cp35-cp35m-linux_x86_64.whl;
    fi

  # install pydot for visualization tests
126 changes: 91 additions & 35 deletions keras/backend/cntk_backend.py
@@ -9,7 +9,6 @@

C.set_global_option('align_axis', 1)

-
b_any = any


@@ -243,7 +242,8 @@ def placeholder(
    if ndim:
        shape = tuple([None for _ in range(ndim)])

-    cntk_shape = [C.InferredDimension if s is None else s for s in shape]
+    dynamic_dimension = C.FreeDimension if _get_cntk_version() >= 2.2 else C.InferredDimension
+    cntk_shape = [dynamic_dimension if s is None else s for s in shape]
    cntk_shape = tuple(cntk_shape)

    if dynamic_axis_num > len(cntk_shape):
@@ -563,9 +563,12 @@ def gather(reference, indices):
    # There is a bug in the cntk gather op which may cause a crash.
    # We made a fix, but it was not included in the CNTK 2.1 release.
    # Will update to the native gather op in the next release.
-    num_class = reference.shape[0]
-    one_hot_matrix = C.ops.one_hot(indices, num_class)
-    return C.times(one_hot_matrix, reference, output_rank=len(reference.shape) - 1)
+    if _get_cntk_version() >= 2.2:
+        return C.ops.gather(reference, indices)
+    else:
+        num_class = reference.shape[0]
+        one_hot_matrix = C.ops.one_hot(indices, num_class)
+        return C.times(one_hot_matrix, reference, output_rank=len(reference.shape) - 1)

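The one-hot branch above reproduces gather as a matrix product: each row of one_hot(indices, num_class) selects one row of reference, so the times call returns the gathered rows. A small NumPy sketch of the equivalence, with hypothetical values:

import numpy as np

reference = np.array([[1., 2.], [3., 4.], [5., 6.]])  # shape (num_class, dim)
indices = np.array([2, 0])

# Row i of the one-hot matrix is all zeros except at indices[i],
# so the product picks out reference[indices[i]], exactly a gather.
one_hot_matrix = np.eye(reference.shape[0])[indices]
assert np.allclose(one_hot_matrix @ reference, reference[indices])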

def _remove_dims(x, axis, keepdims=False):
@@ -665,7 +668,8 @@ def squeeze(x, axis):
    for _ in sorted(_axis, reverse=True):
        del shape[_]

-    new_shape = tuple(shape[nones:])
+    new_shape = shape[nones:]
+    new_shape = tuple([C.InferredDimension if _ == C.FreeDimension else _ for _ in new_shape])
    return C.reshape(x, new_shape)


@@ -753,7 +757,7 @@ def _reshape_dummy_dim(x, axis):

    _axis = [_ + len(shape) if _ < 0 else _ for _ in axis]

-    if shape.count(C.InferredDimension) > 1:
+    if shape.count(C.InferredDimension) > 1 or shape.count(C.FreeDimension) > 1:
        result = x
        for index in sorted(_axis, reverse=True):
            result = C.reshape(result,
@@ -765,7 +769,7 @@ def _reshape_dummy_dim(x, axis):
        for index in sorted(_axis, reverse=True):
            del shape[index]

-        shape = tuple(shape)
+        shape = [C.InferredDimension if _ == C.FreeDimension else _ for _ in shape]
        return C.reshape(x, shape)


@@ -1063,6 +1067,7 @@ def flatten(x):


def reshape(x, shape):
+    shape = tuple([C.InferredDimension if _ == C.FreeDimension else _ for _ in shape])
    if isinstance(x, C.variables.Parameter):
        return C.reshape(x, shape)
    else:
@@ -1077,7 +1082,7 @@ def reshape(x, shape):
                    'collapse of batch axis with inferred dimension. '
                    'The reshape did not take place.')
                return x
-            return C.user_function(ReshapeBatch(x, shape[1:]))
+            return _reshape_batch(x, shape)
        else:
            # no collapse, then first need to padding the shape
            if num_dynamic_axis >= len(shape):
@@ -1161,7 +1166,7 @@ def repeat(x, n):
    # return the same x to take cntk broadcast feature
    # to make the recurrent layer work.
    # need to be fixed in GA.
-    if n is C.InferredDimension:
+    if n is C.InferredDimension or n is C.FreeDimension:
        return x
    index = 1 - _get_dynamic_axis_num(x)
    if index < 0 or index > 1:
@@ -1748,7 +1753,7 @@ def _is_input_shape_compatible(input, placeholder):
        input_shape = input.shape[num_dynamic:]
        placeholder_shape = placeholder.shape
        for i, p in zip(input_shape, placeholder_shape):
-            if i != p and p != C.InferredDimension:
+            if i != p and p != C.InferredDimension and p != C.FreeDimension:
                return False
    return True

@@ -1852,10 +1857,16 @@ def temporal_padding(x, padding=(1, 1)):
    base_shape = x.shape
    if num_dynamic_axis > 0:
        assert len(base_shape) == 2
-        x = _padding(x, padding, 0)
+        if hasattr(C, 'pad'):
+            x = C.pad(x, pattern=[padding, (0, 0)])
+        else:
+            x = _padding(x, padding, 0)
    else:
        assert len(base_shape) == 3
-        x = _padding(x, padding, 1)
+        if hasattr(C, 'pad'):
+            x = C.pad(x, pattern=[(0, 0), padding, (0, 0)])
+        else:
+            x = _padding(x, padding, 1)
    return x


@@ -1872,13 +1883,11 @@ def _padding(x, pattern, axis):
        prefix_shape = tuple(prefix_shape)
        x = C.splice(C.constant(value=0, shape=prefix_shape), x, axis=axis)
        base_shape = x.shape
-
    if pattern[1] > 0:
        postfix_shape = list(base_shape)
        postfix_shape[axis] = pattern[1]
        postfix_shape = tuple(postfix_shape)
        x = C.splice(x, C.constant(value=0, shape=postfix_shape), axis=axis)
-
    return x

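The splice-based fallback above pads one axis by concatenating zero-filled constants before and after x, which is what C.pad now does in a single op. A NumPy sketch of what one _padding call computes, with hypothetical shapes:

import numpy as np

x = np.ones((2, 3))
pattern = (1, 2)  # one zero row before x, two zero rows after, along axis 0

prefix = np.zeros((pattern[0],) + x.shape[1:])
postfix = np.zeros((pattern[1],) + x.shape[1:])
padded = np.concatenate([prefix, x, postfix], axis=0)
assert padded.shape == (5, 3)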

@@ -1896,21 +1905,33 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    if data_format == 'channels_first':
        if num_dynamic_axis > 0:
            assert len(base_shape) == 3
-            x = _padding(x, padding[0], 1)
-            x = _padding(x, padding[1], 2)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1])])
+            else:
+                x = _padding(x, padding[0], 1)
+                x = _padding(x, padding[1], 2)
        else:
            assert len(base_shape) == 4
-            x = _padding(x, padding[0], 2)
-            x = _padding(x, padding[1], 3)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[[0, 0], [0, 0], list(padding[0]), list(padding[1])])
+            else:
+                x = _padding(x, padding[0], 2)
+                x = _padding(x, padding[1], 3)
    else:
        if num_dynamic_axis > 0:
            assert len(base_shape) == 3
-            x = _padding(x, padding[0], 0)
-            x = _padding(x, padding[1], 1)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[list(padding[0]), list(padding[1]), [0, 0]])
+            else:
+                x = _padding(x, padding[0], 0)
+                x = _padding(x, padding[1], 1)
        else:
            assert len(base_shape) == 4
-            x = _padding(x, padding[0], 1)
-            x = _padding(x, padding[1], 2)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1]), [0, 0]])
+            else:
+                x = _padding(x, padding[0], 1)
+                x = _padding(x, padding[1], 2)
    return x

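In every branch above, the pattern argument to C.pad carries one (before, after) pair per static axis, with (0, 0) for axes that stay unpadded, such as the channel axis and, when there is no dynamic batch axis, the batch axis. A NumPy analogue for the channels_last case without dynamic axes, with hypothetical shapes:

import numpy as np

x = np.zeros((1, 4, 5, 3))   # (batch, rows, cols, channels)
padding = ((1, 1), (2, 2))   # pad rows and cols only

# np.pad's pad_width plays the role of C.pad's pattern here.
padded = np.pad(x, ((0, 0),) + padding + ((0, 0),))
assert padded.shape == (1, 6, 9, 3)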

@@ -1929,25 +1950,37 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
    if data_format == 'channels_first':
        if num_dynamic_axis > 0:
            assert len(base_shape) == 4
-            x = _padding(x, padding[0], 1)
-            x = _padding(x, padding[1], 2)
-            x = _padding(x, padding[2], 3)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1]), list(padding[2])])
+            else:
+                x = _padding(x, padding[0], 1)
+                x = _padding(x, padding[1], 2)
+                x = _padding(x, padding[2], 3)
        else:
            assert len(base_shape) == 5
-            x = _padding(x, padding[0], 2)
-            x = _padding(x, padding[1], 3)
-            x = _padding(x, padding[2], 4)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[[0, 0], [0, 0], list(padding[0]), list(padding[1]), list(padding[2])])
+            else:
+                x = _padding(x, padding[0], 2)
+                x = _padding(x, padding[1], 3)
+                x = _padding(x, padding[2], 4)
    else:
        if num_dynamic_axis > 0:
            assert len(base_shape) == 4
-            x = _padding(x, padding[0], 0)
-            x = _padding(x, padding[1], 1)
-            x = _padding(x, padding[2], 2)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[list(padding[0]), list(padding[1]), list(padding[2]), [0, 0]])
+            else:
+                x = _padding(x, padding[0], 0)
+                x = _padding(x, padding[1], 1)
+                x = _padding(x, padding[2], 2)
        else:
            assert len(base_shape) == 5
-            x = _padding(x, padding[0], 1)
-            x = _padding(x, padding[1], 2)
-            x = _padding(x, padding[2], 3)
+            if hasattr(C, 'pad'):
+                x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1]), list(padding[2]), [0, 0]])
+            else:
+                x = _padding(x, padding[0], 1)
+                x = _padding(x, padding[1], 2)
+                x = _padding(x, padding[2], 3)
    return x


@@ -2249,6 +2282,29 @@ def reverse(x, axes):
    return C.slice(x, cntk_axes, begin_index, end_index, strides)


+def _reshape_batch(x, shape):
+    # There is a bug in CNTK 2.1's unpack_batch implementation,
+    # so the native path is only used on CNTK 2.2 and later.
+    if hasattr(C, 'unpack_batch') and _get_cntk_version() >= 2.2:
+        const_a = C.unpack_batch(x)
+        const_a = C.reshape(const_a, shape)
+        return C.to_batch(const_a)
+    else:
+        return C.user_function(ReshapeBatch(x, shape[1:]))

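The unpack_batch/to_batch pair above turns CNTK's dynamic batch axis into an ordinary static axis so a plain reshape can span it, then packs the result back. A rough NumPy analogy of the three steps, with hypothetical values (in CNTK the batch axis is dynamic, not a real array axis, so this is only an analogy):

import numpy as np

x = np.arange(12).reshape(4, 3)    # a "batch" of 4 samples, each of shape (3,)

unpacked = x                       # unpack_batch: batch axis becomes static axis 0
reshaped = unpacked.reshape(2, 6)  # the reshape may now change the batch size
repacked = reshaped                # to_batch: axis 0 becomes the batch axis again
assert repacked.shape == (2, 6)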

+def _get_cntk_version():
+    version = C.__version__
+    if version.endswith('+'):
+        version = version[:-1]
+    try:
+        return float(version)
+    except:
+        warnings.warn(
+            'CNTK backend warning: CNTK version not detected. '
+            'Will use CNTK 2.0 GA as the default.')
+        return float(2.0)

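A quick check of the parsing logic above, rewritten to take the version string as an argument; _parse_cntk_version is a hypothetical name used here so the sketch runs without importing CNTK:

def _parse_cntk_version(version):
    # Mirrors _get_cntk_version: strip the dev-build '+' suffix, then parse.
    if version.endswith('+'):
        version = version[:-1]
    try:
        return float(version)
    except ValueError:
        return 2.0  # treat unrecognized version strings as 2.0 GA

assert _parse_cntk_version('2.2') == 2.2
assert _parse_cntk_version('2.2+') == 2.2   # dev/nightly build
assert _parse_cntk_version('unknown') == 2.0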

class ReshapeBatch(C.ops.functions.UserFunction):
    def __init__(self, input, shape, name='reshape_with_batch'):
        super(ReshapeBatch, self).__init__([input], as_numpy=False, name=name)
11 changes: 0 additions & 11 deletions tests/keras/activations_test.py
@@ -173,11 +173,6 @@ def test_elu():
    assert_allclose(result, test_values, rtol=1e-05)

    negative_values = np.array([[-1, -2]], dtype=K.floatx())
-    # cntk can't rebind the input shape, so create the model again to
-    # test different batch size
-    if (K.backend() == 'cntk'):
-        x2 = K.placeholder(ndim=2)
-        f = K.function([x2], [activations.elu(x2, 0.5)])
    result = f([negative_values])[0]
    true_result = (np.exp(negative_values) - 1) / 2

@@ -196,12 +191,6 @@ def test_selu():

    negative_values = np.array([[-1, -2]], dtype=K.floatx())

-    # cntk can't rebind the input shape, so create the model again to
-    # test different batch size
-    if (K.backend() == 'cntk'):
-        x2 = K.placeholder(ndim=2)
-        f = K.function([x2], [activations.selu(x2)])
-
    result = f([negative_values])[0]
    true_result = (np.exp(negative_values) - 1) * scale * alpha

