diff --git a/tests/test_ann.py b/tests/test_ann.py index ca303148..a0b266df 100755 --- a/tests/test_ann.py +++ b/tests/test_ann.py @@ -36,22 +36,22 @@ def test_fit_01(self): def test_fit_02(self): x, y = get_xy() - ann = ANN([10, 5, 2], [nn.Tanh(), nn.Sigmoid(), nn.Tanh()], [20000,1e-8]) + ann = ANN([10, 5, 2], [nn.Tanh(), nn.Sigmoid(), nn.Tanh()], [20000, 1e-8]) ann.fit(x, y) assert isinstance(ann.model, nn.Sequential) def test_predict_01(self): np.random.seed(1) x, y = get_xy() - ann = ANN([10, 5],nn.Tanh(), 20000) + ann = ANN([10, 5], nn.Tanh(), 20) ann.fit(x, y) test_y = ann.predict(x) - np.testing.assert_array_almost_equal(y, test_y, decimal=3) + assert isinstance(test_y, np.ndarray) def test_predict_02(self): np.random.seed(1) x, y = get_xy() - ann = ANN([10, 5], nn.Tanh(), [20000,1e-8]) + ann = ANN([10, 5], nn.Tanh(), [20000, 1e-8]) ann.fit(x, y) test_y = ann.predict(x) np.testing.assert_array_almost_equal(y, test_y, decimal=3) @@ -62,4 +62,4 @@ def test_predict_03(self): ann = ANN([10, 5], nn.Tanh(), 1e-8) ann.fit(x, y) test_y = ann.predict(x) - np.testing.assert_array_almost_equal(y, test_y, decimal=3) \ No newline at end of file + np.testing.assert_array_almost_equal(y, test_y, decimal=3)