Skip to content

Commit

Permalink
Parametrizing the metric tests to get one test per metric. (#11378)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse authored and fchollet committed Oct 14, 2018
1 parent 161f0a8 commit 02c1f88
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions tests/keras/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,19 @@
]


def test_metrics():
@pytest.mark.parametrize('metric', all_metrics)
def test_metrics(metric):
y_a = K.variable(np.random.random((6, 7)))
y_b = K.variable(np.random.random((6, 7)))
for metric in all_metrics:
output = metric(y_a, y_b)
print(metric.__name__)
assert K.eval(output).shape == (6,)
output = metric(y_a, y_b)
assert K.eval(output).shape == (6,)


def test_sparse_metrics():
for metric in all_sparse_metrics:
y_a = K.variable(np.random.randint(0, 7, (6,)), dtype=K.floatx())
y_b = K.variable(np.random.random((6, 7)), dtype=K.floatx())
assert K.eval(metric(y_a, y_b)).shape == (6,)
@pytest.mark.parametrize('metric', all_sparse_metrics)
def test_sparse_metrics(metric):
y_a = K.variable(np.random.randint(0, 7, (6,)), dtype=K.floatx())
y_b = K.variable(np.random.random((6, 7)), dtype=K.floatx())
assert K.eval(metric(y_a, y_b)).shape == (6,)


@pytest.mark.parametrize('shape', [(6,), (6, 3), (6, 3, 1)])
Expand Down

0 comments on commit 02c1f88

Please sign in to comment.