Skip to content

Commit

Permalink
Improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
jungtaekkim committed Dec 31, 2020
1 parent d3bd940 commit 4849234
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 2 deletions.
10 changes: 8 additions & 2 deletions bayeso/utils/utils_covariance.py
Expand Up @@ -283,7 +283,9 @@ def validate_hyps_dict(hyps: dict, str_cov: str, dim: int,
if 'dof' not in hyps:
is_valid = False
else:
if hyps['dof'] <= 2.0:
if not isinstance(hyps['dof'], float):
is_valid = False
if isinstance(hyps['dof'], float) and hyps['dof'] <= 2.0:
hyps['dof'] = 2.00001

if str_cov in ('eq', 'se', 'matern32', 'matern52'):
Expand All @@ -307,7 +309,8 @@ def validate_hyps_dict(hyps: dict, str_cov: str, dim: int,
return hyps, is_valid

@utils_common.validate_types
def validate_hyps_arr(hyps: np.ndarray, str_cov: str, dim: int
def validate_hyps_arr(hyps: np.ndarray, str_cov: str, dim: int,
use_gp: bool=True
) -> constants.TYPING_TUPLE_ARRAY_BOOL:
"""
It validates hyperparameters array, `hyps`.
Expand All @@ -318,6 +321,8 @@ def validate_hyps_arr(hyps: np.ndarray, str_cov: str, dim: int
:type str_cov: str.
:param dim: dimensionality of the problem we are solving.
:type dim: int.
:param use_gp: flag for Gaussian process or Student-$t$ process.
:type use_gp: bool., optional
:returns: a tuple of valid hyperparameters and validity flag.
:rtype: (numpy.ndarray, bool.)
Expand All @@ -329,6 +334,7 @@ def validate_hyps_arr(hyps: np.ndarray, str_cov: str, dim: int
assert isinstance(hyps, np.ndarray)
assert isinstance(str_cov, str)
assert isinstance(dim, int)
assert isinstance(use_gp, bool)
assert str_cov in constants.ALLOWED_GP_COV

# is_valid = True
Expand Down
4 changes: 4 additions & 0 deletions tests/common/test_tp.py
Expand Up @@ -126,6 +126,10 @@ def test_sample_functions():
assert functions.shape[1] == num_points
assert functions.shape[0] == num_samples

functions = package_target.sample_functions(np.inf, mu, Sigma, num_samples=num_samples)
assert functions.shape[1] == num_points
assert functions.shape[0] == num_samples

def test_get_optimized_kernel_typing():
annos = package_target.get_optimized_kernel.__annotations__

Expand Down
93 changes: 93 additions & 0 deletions tests/common/test_utils_covariance.py
Expand Up @@ -23,6 +23,7 @@ def test_get_hyps_typing():
assert annos['str_cov'] == str
assert annos['dim'] == int
assert annos['use_ard'] == bool
assert annos['use_gp'] == bool
assert annos['return'] == dict

def test_get_hyps():
Expand All @@ -36,6 +37,8 @@ def test_get_hyps():
package_target.get_hyps('abc', 2)
with pytest.raises(AssertionError) as error:
package_target.get_hyps('se', 2, use_ard='abc')
with pytest.raises(AssertionError) as error:
package_target.get_hyps('se', 2, use_gp='abc')

cur_hyps = package_target.get_hyps('se', 2)
assert cur_hyps['noise'] == constants.GP_NOISE
Expand All @@ -58,12 +61,25 @@ def test_get_hyps():
assert cur_hyps['signal'] == 1.0
assert cur_hyps['lengthscales'] == 1.0

cur_hyps = package_target.get_hyps('matern32', 2, use_ard=False, use_gp=False)
assert cur_hyps['noise'] == constants.GP_NOISE
assert cur_hyps['signal'] == 1.0
assert cur_hyps['lengthscales'] == 1.0
assert cur_hyps['dof'] == 5.0

cur_hyps = package_target.get_hyps('matern32', 2, use_ard=True, use_gp=False)
assert cur_hyps['noise'] == constants.GP_NOISE
assert cur_hyps['signal'] == 1.0
assert np.all(cur_hyps['lengthscales'] == np.array([1.0, 1.0]))
assert cur_hyps['dof'] == 5.0

def test_get_range_hyps_typing():
annos = package_target.get_range_hyps.__annotations__

assert annos['str_cov'] == str
assert annos['dim'] == int
assert annos['use_ard'] == bool
assert annos['use_gp'] == bool
assert annos['fix_noise'] == bool
assert annos['return'] == list

Expand All @@ -78,6 +94,10 @@ def test_get_range_hyps():
package_target.get_range_hyps('se', 2, use_ard='abc')
with pytest.raises(AssertionError) as error:
package_target.get_range_hyps('se', 2, use_ard=1)
with pytest.raises(AssertionError) as error:
package_target.get_range_hyps('se', 2, use_gp='abc')
with pytest.raises(AssertionError) as error:
package_target.get_range_hyps('se', 2, use_gp=1)
with pytest.raises(AssertionError) as error:
package_target.get_range_hyps('se', 2, fix_noise=1)
with pytest.raises(AssertionError) as error:
Expand All @@ -89,12 +109,19 @@ def test_get_range_hyps():
assert isinstance(cur_range, list)
assert cur_range == [[0.001, 10.0], [0.01, 1000.0], [0.01, 1000.0]]

cur_range = package_target.get_range_hyps('se', 2, use_ard=False, fix_noise=False, use_gp=False)
print(type(cur_range))
print(cur_range)
assert isinstance(cur_range, list)
assert cur_range == [[0.001, 10.0], [2.00001, 200.0], [0.01, 1000.0], [0.01, 1000.0]]

def test_convert_hyps_typing():
annos = package_target.convert_hyps.__annotations__

assert annos['str_cov'] == str
assert annos['hyps'] == dict
assert annos['fix_noise'] == bool
assert annos['use_gp'] == bool
assert annos['return'] == np.ndarray

def test_convert_hyps():
Expand All @@ -112,6 +139,10 @@ def test_convert_hyps():
package_target.convert_hyps('abc', cur_hyps)
with pytest.raises(AssertionError) as error:
package_target.convert_hyps('se', dict(), fix_noise=1)
with pytest.raises(AssertionError) as error:
package_target.convert_hyps('se', cur_hyps, use_gp=1)
with pytest.raises(AssertionError) as error:
package_target.convert_hyps('se', cur_hyps, use_gp='abc')

converted_hyps = package_target.convert_hyps('se', cur_hyps, fix_noise=False)
assert len(converted_hyps.shape) == 1
Expand All @@ -126,11 +157,29 @@ def test_convert_hyps():
assert converted_hyps[0] == cur_hyps['signal']
assert (converted_hyps[1:] == cur_hyps['lengthscales']).all()

cur_hyps = {'noise': 0.1, 'signal': 1.0, 'lengthscales': np.array([2.0, 2.0]), 'dof': 100.0}
converted_hyps = package_target.convert_hyps('se', cur_hyps, fix_noise=False, use_gp=False)

assert len(converted_hyps.shape) == 1
assert converted_hyps.shape[0] == 5
assert converted_hyps[0] == cur_hyps['noise']
assert converted_hyps[1] == cur_hyps['dof']
assert converted_hyps[2] == cur_hyps['signal']
assert (converted_hyps[3:] == cur_hyps['lengthscales']).all()

converted_hyps = package_target.convert_hyps('se', cur_hyps, fix_noise=True, use_gp=False)
assert len(converted_hyps.shape) == 1
assert converted_hyps.shape[0] == 4
assert converted_hyps[0] == cur_hyps['dof']
assert converted_hyps[1] == cur_hyps['signal']
assert (converted_hyps[2:] == cur_hyps['lengthscales']).all()

def test_restore_hyps_typing():
annos = package_target.restore_hyps.__annotations__

assert annos['str_cov'] == str
assert annos['hyps'] == np.ndarray
assert annos['use_gp'] == bool
assert annos['fix_noise'] == bool
assert annos['noise'] == float
assert annos['return'] == dict
Expand All @@ -150,6 +199,10 @@ def test_restore_hyps():
package_target.restore_hyps('se', np.array([1.0, 1.0, 1.0]), fix_noise=1)
with pytest.raises(AssertionError) as error:
package_target.restore_hyps('se', np.array([1.0, 1.0, 1.0]), noise='abc')
with pytest.raises(AssertionError) as error:
package_target.restore_hyps('se', np.array([1.0, 1.0, 1.0]), use_gp=1)
with pytest.raises(AssertionError) as error:
package_target.restore_hyps('se', np.array([1.0, 1.0, 1.0]), use_gp='abc')

cur_hyps = np.array([0.1, 1.0, 1.0, 1.0, 1.0])
restored_hyps = package_target.restore_hyps('se', cur_hyps, fix_noise=False)
Expand All @@ -162,12 +215,27 @@ def test_restore_hyps():
assert restored_hyps['signal'] == cur_hyps[0]
assert (restored_hyps['lengthscales'] == cur_hyps[1:]).all()

cur_hyps = np.array([0.1, 100.0, 20.0, 1.0, 1.0, 1.0])
restored_hyps = package_target.restore_hyps('se', cur_hyps, fix_noise=False, use_gp=False)
assert restored_hyps['noise'] == cur_hyps[0]
assert restored_hyps['dof'] == cur_hyps[1]
assert restored_hyps['signal'] == cur_hyps[2]
assert (restored_hyps['lengthscales'] == cur_hyps[3:]).all()

cur_hyps = np.array([100.0, 20.0, 1.0, 1.0, 1.0])
restored_hyps = package_target.restore_hyps('se', cur_hyps, fix_noise=True, use_gp=False)
assert restored_hyps['noise'] == constants.GP_NOISE
assert restored_hyps['dof'] == cur_hyps[0]
assert restored_hyps['signal'] == cur_hyps[1]
assert (restored_hyps['lengthscales'] == cur_hyps[2:]).all()

def test_validate_hyps_dict_typing():
annos = package_target.validate_hyps_dict.__annotations__

assert annos['hyps'] == dict
assert annos['str_cov'] == str
assert annos['dim'] == int
assert annos['use_gp'] == bool
assert annos['return'] == typing.Tuple[dict, bool]

def test_validate_hyps_dict():
Expand All @@ -181,6 +249,10 @@ def test_validate_hyps_dict():
_, is_valid = package_target.validate_hyps_dict(cur_hyps, 'abc', num_dim)
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_dict(cur_hyps, str_cov, 'abc')
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_dict(cur_hyps, str_cov, num_dim, use_gp=1)
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_dict(cur_hyps, str_cov, num_dim, use_gp='abc')

cur_hyps = package_target.get_hyps(str_cov, num_dim)
cur_hyps.pop('noise')
Expand Down Expand Up @@ -228,12 +300,31 @@ def test_validate_hyps_dict():
_, is_valid = package_target.validate_hyps_dict(cur_hyps, str_cov, num_dim)
assert is_valid == True

cur_hyps = package_target.get_hyps(str_cov, num_dim, use_gp=False)
cur_hyps['signal'] = 'abc'
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_dict(cur_hyps, str_cov, num_dim)
assert is_valid == True

cur_hyps = package_target.get_hyps(str_cov, num_dim, use_gp=False)
cur_hyps['dof'] = 'abc'
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_dict(cur_hyps, str_cov, num_dim, use_gp=False)
assert is_valid == True

cur_hyps = package_target.get_hyps(str_cov, num_dim, use_ard=False, use_gp=False)
cur_hyps['lengthscales'] = 'abc'
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_dict(cur_hyps, str_cov, num_dim, use_gp=False)
assert is_valid == True

def test_validate_hyps_arr_typing():
annos = package_target.validate_hyps_arr.__annotations__

assert annos['hyps'] == np.ndarray
assert annos['str_cov'] == str
assert annos['dim'] == int
assert annos['use_gp'] == bool
assert annos['return'] == typing.Tuple[np.ndarray, bool]

def test_validate_hyps_arr():
Expand All @@ -248,6 +339,8 @@ def test_validate_hyps_arr():
_, is_valid = package_target.validate_hyps_arr(cur_hyps, 'abc', num_dim)
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_arr(cur_hyps, str_cov, 'abc')
with pytest.raises(AssertionError) as error:
_, is_valid = package_target.validate_hyps_arr(cur_hyps, str_cov, num_dim, use_gp='abc')

def test_check_str_cov_typing():
annos = package_target.check_str_cov.__annotations__
Expand Down

0 comments on commit 4849234

Please sign in to comment.