Skip to content

Commit

Permalink
Revise backend tests for dtype (#12544)
Browse files Browse the repository at this point in the history
* Revise backend tests for dtype

* Revise `check_dtype` for CNTK

* Revise `check_dtype` for CNTK

* Revise `check_dtype` for CNTK

* Improve `check_dtype`
  • Loading branch information
taehoonlee authored and fchollet committed Apr 3, 2019
1 parent 91efaaa commit d789bd9
Showing 1 changed file with 26 additions and 51 deletions.
77 changes: 26 additions & 51 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings

from keras import backend as K
from keras.backend import floatx, set_floatx, variable
from keras.utils.conv_utils import convert_kernel
from keras.backend import numpy_backend as KNP

Expand Down Expand Up @@ -44,10 +43,10 @@


def check_dtype(var, dtype):
if K.backend() == 'theano':
assert var.dtype == dtype
if K.backend() == 'tensorflow':
assert dtype in str(var.dtype.name)
else:
assert var.dtype.name == '%s_ref' % dtype
assert dtype in str(var.dtype)


def cntk_func_tensors(function_name, shapes_or_vals, **kwargs):
Expand Down Expand Up @@ -2020,61 +2019,37 @@ def test_in_test_phase(self, training):
check_two_tensor_operation('in_test_phase', (2, 3), (2, 3), WITH_NP,
training=training)

def test_setfloatx_incorrect_values(self):
@pytest.mark.parametrize('dtype', ['', 'beerfloat', 123])
def test_setfloatx_incorrect_values(self, dtype):
# Keep track of the old value
old_floatx = floatx()
# Try some incorrect values
initial = floatx()
for value in ['', 'beerfloat', 123]:
with pytest.raises(ValueError):
set_floatx(value)
assert floatx() == initial
# Restore old value
set_floatx(old_floatx)
old_floatx = K.floatx()
with pytest.raises(ValueError):
K.set_floatx(dtype)
assert K.floatx() == old_floatx

def test_setfloatx_correct_values(self):
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
def test_setfloatx_correct_values(self, dtype):
# Keep track of the old value
old_floatx = floatx()
old_floatx = K.floatx()
# Check correct values
for value in ['float16', 'float32', 'float64']:
set_floatx(value)
assert floatx() == value
K.set_floatx(dtype)
assert K.floatx() == dtype
# Make sure that changes to the global floatx are effectively
# taken into account by the backend.
check_dtype(K.variable([10]), dtype)
# Restore old value
set_floatx(old_floatx)

@pytest.mark.skipif((K.backend() == 'cntk'),
reason='cntk does not support float16')
def test_set_floatx(self):
"""
Make sure that changes to the global floatx are effectively
taken into account by the backend.
"""
# Keep track of the old value
old_floatx = floatx()

set_floatx('float16')
var = variable([10])
check_dtype(var, 'float16')
K.set_floatx(old_floatx)

set_floatx('float64')
var = variable([10])
check_dtype(var, 'float64')

# Restore old value
set_floatx(old_floatx)

def test_dtype(self):
assert K.dtype(K.variable(1, dtype='float64')) == 'float64'
assert K.dtype(K.variable(1, dtype='float32')) == 'float32'
assert K.dtype(K.variable(1, dtype='float16')) == 'float16'
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
def test_dtype(self, dtype):
assert K.dtype(K.variable(1, dtype=dtype)) == dtype

@pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported')
def test_variable_support_bool_dtype(self):
# Github issue: 7819
if K.backend() == 'tensorflow':
assert K.dtype(K.variable(1, dtype='int16')) == 'int16'
assert K.dtype(K.variable(False, dtype='bool')) == 'bool'
with pytest.raises(TypeError):
K.variable('', dtype='unsupported')
assert K.dtype(K.variable(1, dtype='int16')) == 'int16'
assert K.dtype(K.variable(False, dtype='bool')) == 'bool'
with pytest.raises(TypeError):
K.variable('', dtype='unsupported')

@pytest.mark.parametrize('shape', [(4, 2), (2, 3)])
def test_clip_supports_tensor_arguments(self, shape):
Expand Down

0 comments on commit d789bd9

Please sign in to comment.