Skip to content

Commit

Permalink
Update tests for trees
Browse files Browse the repository at this point in the history
  • Loading branch information
jungtaekkim committed Aug 18, 2021
1 parent a632a0a commit c2bc141
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 2 deletions.
1 change: 1 addition & 0 deletions bayeso/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
TYPING_TUPLE_ARRAY_FLOAT = typing.Tuple[TYPE_ARR, float]
TYPING_TUPLE_TWO_ARRAYS = typing.Tuple[TYPE_ARR, TYPE_ARR]
TYPING_TUPLE_TWO_ARRAYS_DICT = typing.Tuple[TYPE_ARR, TYPE_ARR, dict]
TYPING_TUPLE_TWO_FLOATS = typing.Tuple[float, float]
TYPING_TUPLE_THREE_ARRAYS = typing.Tuple[TYPE_ARR, TYPE_ARR, TYPE_ARR]
TYPING_TUPLE_FIVE_ARRAYS = typing.Tuple[TYPE_ARR, TYPE_ARR, TYPE_ARR, TYPE_ARR, TYPE_ARR]
TYPING_TUPLE_FLOAT_THREE_ARRAYS = typing.Tuple[float, TYPE_ARR, TYPE_ARR, TYPE_ARR]
Expand Down
4 changes: 2 additions & 2 deletions bayeso/trees/trees_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def split(
split(node['right'], depth_max, size_min_leaf, num_features, split_random_location, cur_depth + 1)

@utils_common.validate_types
def _predict_by_tree(bx: np.ndarray, tree: dict) -> constants.TYPING_TUPLE_TWO_ARRAYS:
def _predict_by_tree(bx: np.ndarray, tree: dict) -> constants.TYPING_TUPLE_TWO_FLOATS:
assert isinstance(bx, np.ndarray)
assert isinstance(tree, dict)

Expand All @@ -237,7 +237,7 @@ def _predict_by_tree(bx: np.ndarray, tree: dict) -> constants.TYPING_TUPLE_TWO_A
return np.mean(cur_Y), np.std(cur_Y)

@utils_common.validate_types
def _predict_by_trees(bx: np.ndarray, list_trees: list) -> constants.TYPING_TUPLE_TWO_ARRAYS:
def _predict_by_trees(bx: np.ndarray, list_trees: list) -> constants.TYPING_TUPLE_TWO_FLOATS:
assert isinstance(bx, np.ndarray)
assert isinstance(list_trees, list)

Expand Down
281 changes: 281 additions & 0 deletions tests/common/test_trees_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,284 @@ def test__split():

assert np.all(dict_split['left_right'][1][0][0] == np.array([36, 37, 38, 39]))
assert np.abs(dict_split['left_right'][1][0][1] - np.array([0.54256004])) < TEST_EPSILON

dict_split = package_target._split(X, Y, num_features, True)
print(dict_split)
print(dict_split['index'])
print(dict_split['value'])
print(dict_split['left_right'])

assert isinstance(dict_split, dict)
assert dict_split['index'] == 0
assert dict_split['value'] == 0.2543869879098266

X = np.ones(X.shape)

dict_split = package_target._split(X, Y, num_features, True)
print(dict_split)
print(dict_split['index'])
print(dict_split['value'])
print(dict_split['left_right'])

assert isinstance(dict_split, dict)
assert dict_split['index'] == 1
assert dict_split['value'] == 1.0

def test_split_typing():
annos = package_target.split.__annotations__

assert annos['node'] == dict
assert annos['depth_max'] == int
assert annos['size_min_leaf'] == int
assert annos['num_features'] == int
assert annos['split_random_location'] == bool
assert annos['cur_depth'] == int
assert annos['return'] == constants.TYPE_NONE

def test_split():
np.random.seed(42)

X = np.reshape(np.arange(0, 40), (10, 4))
Y = np.random.randn(10, 1)
depth_max = 4
size_min_leaf = 2
num_features = 2
split_random_location = False

node = package_target._split(X, Y, num_features, split_random_location)

with pytest.raises(AssertionError) as error:
package_target.split(node, depth_max, size_min_leaf, num_features, split_random_location, 'abc')
with pytest.raises(AssertionError) as error:
package_target.split(node, depth_max, size_min_leaf, num_features, split_random_location, 1.0)
with pytest.raises(AssertionError) as error:
package_target.split(node, depth_max, size_min_leaf, num_features, 'abc', 1)
with pytest.raises(AssertionError) as error:
package_target.split(node, depth_max, size_min_leaf, 'abc', split_random_location, 1)
with pytest.raises(AssertionError) as error:
package_target.split(node, depth_max, 'abc', num_features, split_random_location, 1)
with pytest.raises(AssertionError) as error:
package_target.split(node, 'abc', size_min_leaf, num_features, split_random_location, 1)
with pytest.raises(AssertionError) as error:
package_target.split(X, depth_max, size_min_leaf, num_features, split_random_location, 1)
with pytest.raises(AssertionError) as error:
package_target.split('abc', depth_max, size_min_leaf, num_features, split_random_location, 1)

package_target.split(node, depth_max, size_min_leaf, num_features, split_random_location, 1)
assert isinstance(node, dict)

def test__predict_by_tree_typing():
annos = package_target._predict_by_tree.__annotations__

assert annos['bx'] == np.ndarray
assert annos['tree'] == dict
assert annos['return'] == constants.TYPING_TUPLE_TWO_FLOATS

def test__predict_by_tree():
np.random.seed(42)

X = np.reshape(np.arange(0, 40), (10, 4))
Y = np.random.randn(10, 1)
depth_max = 4
size_min_leaf = 2
num_features = 2
split_random_location = True

node = package_target._split(X, Y, num_features, split_random_location)
package_target.split(node, depth_max, size_min_leaf, num_features, split_random_location, 1)

with pytest.raises(AssertionError) as error:
package_target._predict_by_tree(np.array([4.0, 2.0, 3.0, 1.0]), 'abc')
with pytest.raises(AssertionError) as error:
package_target._predict_by_tree(X, node)
with pytest.raises(AssertionError) as error:
package_target._predict_by_tree('abc', node)

mean, std = package_target._predict_by_tree(np.array([4.0, 2.0, 3.0, 1.0]), node)
print(mean)
print(std)

assert mean == 0.179224925920024
assert std == 0.31748922709120864

def test__predict_by_trees_typing():
annos = package_target._predict_by_trees.__annotations__

assert annos['bx'] == np.ndarray
assert annos['list_trees'] == list
assert annos['return'] == constants.TYPING_TUPLE_TWO_FLOATS

def test__predict_by_trees():
np.random.seed(42)

X = np.reshape(np.arange(0, 40), (10, 4))
Y = np.random.randn(10, 1)
depth_max = 4
size_min_leaf = 2
num_features = 2
split_random_location = True

node_1 = package_target._split(X, Y, num_features, split_random_location)
package_target.split(node_1, depth_max, size_min_leaf, num_features, split_random_location, 1)

node_2 = package_target._split(X, Y, num_features, split_random_location)
package_target.split(node_2, depth_max, size_min_leaf, num_features, split_random_location, 1)

node_3 = package_target._split(X, Y, num_features, split_random_location)
package_target.split(node_3, depth_max, size_min_leaf, num_features, split_random_location, 1)

list_trees = [node_1, node_2, node_3]

with pytest.raises(AssertionError) as error:
package_target._predict_by_trees(np.array([4.0, 2.0, 3.0, 1.0]), node_1)
with pytest.raises(AssertionError) as error:
package_target._predict_by_trees(np.array([4.0, 2.0, 3.0, 1.0]), 'abc')
with pytest.raises(AssertionError) as error:
package_target._predict_by_trees(X, list_trees)
with pytest.raises(AssertionError) as error:
package_target._predict_by_trees('abc', list_trees)

mean, std = package_target._predict_by_trees(np.array([4.0, 2.0, 3.0, 1.0]), list_trees)
print(mean)
print(std)

assert mean == 0.12544669602080652
assert std == 0.3333040901154691

def test_predict_by_trees_typing():
annos = package_target.predict_by_trees.__annotations__

assert annos['X'] == np.ndarray
assert annos['list_trees'] == list
assert annos['return'] == constants.TYPING_TUPLE_TWO_ARRAYS

def test_predict_by_trees():
np.random.seed(42)

X = np.reshape(np.arange(0, 40), (10, 4))
Y = np.random.randn(10, 1)
depth_max = 4
size_min_leaf = 2
num_features = 2
split_random_location = True

node_1 = package_target._split(X, Y, num_features, split_random_location)
package_target.split(node_1, depth_max, size_min_leaf, num_features, split_random_location, 1)

node_2 = package_target._split(X, Y, num_features, split_random_location)
package_target.split(node_2, depth_max, size_min_leaf, num_features, split_random_location, 1)

node_3 = package_target._split(X, Y, num_features, split_random_location)
package_target.split(node_3, depth_max, size_min_leaf, num_features, split_random_location, 1)

list_trees = [node_1, node_2, node_3]

with pytest.raises(AssertionError) as error:
package_target.predict_by_trees(np.array([4.0, 2.0, 3.0, 1.0]), node_1)
with pytest.raises(AssertionError) as error:
package_target.predict_by_trees(np.array([4.0, 2.0, 3.0, 1.0]), 'abc')
with pytest.raises(AssertionError) as error:
package_target.predict_by_trees(np.array([4.0, 2.0, 3.0, 1.0]), list_trees)
with pytest.raises(AssertionError) as error:
package_target.predict_by_trees('abc', list_trees)

means, stds = package_target.predict_by_trees(X, list_trees)
print(means)
print(stds)

means_truth = np.array([
[0.33710618],
[0.1254467],
[0.68947573],
[0.9866563],
[0.16257799],
[0.16258346],
[0.85148454],
[0.55084716],
[0.30721653],
[0.30721653],
])

stds_truth = np.array([
[0.29842457],
[0.33330409],
[0.44398864],
[0.72536523],
[0.74232577],
[0.74232284],
[0.83388663],
[0.5615399],
[0.64331582],
[0.64331582],
])

assert isinstance(means, np.ndarray)
assert isinstance(stds, np.ndarray)
assert len(means.shape) == 2
assert len(stds.shape) == 2
assert means.shape[0] == stds.shape[0] == X.shape[0]
assert means.shape[1] == stds.shape[1] == 1

assert np.all(np.abs(means - means_truth) < TEST_EPSILON)
assert np.all(np.abs(stds - stds_truth) < TEST_EPSILON)

def test_compute_sigma_typing():
annos = package_target.compute_sigma.__annotations__

assert annos['preds_mu_leaf'] == np.ndarray
assert annos['preds_sigma_leaf'] == np.ndarray
assert annos['min_sigma'] == float
assert annos['return'] == np.ndarray

def test_compute_sigma():
means_leaf = np.array([
1.0,
2.0,
3.0,
9.0,
8.0,
4.0,
5.0,
6.0,
7.0,
10.0,
])

stds_leaf = np.array([
-1.0,
0.0,
1.0,
2.0,
1.0,
1.0,
4.0,
3.0,
4.0,
-2.0,
])
min_sigma = 0.0

with pytest.raises(AssertionError) as error:
package_target.compute_sigma(means_leaf, stds_leaf, min_sigma='abc')
with pytest.raises(AssertionError) as error:
package_target.compute_sigma(means_leaf, stds_leaf, min_sigma=4)
with pytest.raises(AssertionError) as error:
package_target.compute_sigma(means_leaf, 'abc', min_sigma=min_sigma)
with pytest.raises(AssertionError) as error:
package_target.compute_sigma(means_leaf, np.array([[1.0], [2.0], [1.0]]), min_sigma=min_sigma)
with pytest.raises(AssertionError) as error:
package_target.compute_sigma('abc', stds_leaf, min_sigma=min_sigma)
with pytest.raises(AssertionError) as error:
package_target.compute_sigma(np.array([[1.0], [2.0], [1.0]]), stds_leaf, min_sigma=min_sigma)

sigma = package_target.compute_sigma(means_leaf, stds_leaf, min_sigma=min_sigma)
print(sigma)

sigma_truth = np.mean(means_leaf**2 + np.maximum(stds_leaf, min_sigma)**2)
print(sigma_truth)
sigma_truth -= np.mean(means_leaf)**2
print(sigma_truth)
sigma_truth = np.sqrt(sigma_truth)
print(sigma_truth)

assert sigma == sigma_truth == 3.612478373637688

0 comments on commit c2bc141

Please sign in to comment.