Skip to content

Commit

Permalink
Change test ann
Browse files Browse the repository at this point in the history
  • Loading branch information
ndem0 committed Mar 3, 2021
1 parent 7f35a0b commit 9ab8489
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/test_ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,22 @@ def test_fit_01(self):

def test_fit_02(self):
x, y = get_xy()
ann = ANN([10, 5, 2], [nn.Tanh(), nn.Sigmoid(), nn.Tanh()], [20000,1e-8])
ann = ANN([10, 5, 2], [nn.Tanh(), nn.Sigmoid(), nn.Tanh()], [20000, 1e-8])
ann.fit(x, y)
assert isinstance(ann.model, nn.Sequential)

def test_predict_01(self):
np.random.seed(1)
x, y = get_xy()
ann = ANN([10, 5],nn.Tanh(), 20000)
ann = ANN([10, 5], nn.Tanh(), 20)
ann.fit(x, y)
test_y = ann.predict(x)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)
assert isinstance(test_y, np.ndarray)

def test_predict_02(self):
np.random.seed(1)
x, y = get_xy()
ann = ANN([10, 5], nn.Tanh(), [20000,1e-8])
ann = ANN([10, 5], nn.Tanh(), [20000, 1e-8])
ann.fit(x, y)
test_y = ann.predict(x)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)
Expand All @@ -62,4 +62,4 @@ def test_predict_03(self):
ann = ANN([10, 5], nn.Tanh(), 1e-8)
ann.fit(x, y)
test_y = ann.predict(x)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)

0 comments on commit 9ab8489

Please sign in to comment.