Skip to content

Commit

Permalink
new tests for class ann.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fandreuz committed Mar 5, 2021
1 parent 43fd50d commit 90c2eef
Showing 1 changed file with 73 additions and 2 deletions.
75 changes: 73 additions & 2 deletions tests/test_ann.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch.nn as nn
from torch import Tensor

from unittest import TestCase
from ezyrb import ANN
Expand All @@ -22,6 +23,31 @@ class TestANN(TestCase):
def test_constructor_empty(self):
ann = ANN([10, 5], nn.Tanh(), 20000)

def test_constrctor_loss_none(self):
ann = ANN([10, 5], nn.Tanh(), 20000, loss=None)
assert isinstance(ann.loss, nn.MSELoss)

def test_constructor_single_function(self):
passed_func = nn.Tanh()
ann = ANN([10, 5], passed_func, 20000)

assert isinstance(ann.function, list)
for func in ann.function:
assert func == passed_func

def test_constructor_layers(self):
ann = ANN([10, 5], nn.Tanh(), 20000)
assert ann.layers == [10, 5]

def test_constructor_stop_training(self):
ann = ANN([10, 5], nn.Tanh(), 20000)
assert ann.stop_training == 20000

def test_constructor_fields_initialized(self):
ann = ANN([10, 5], nn.Tanh(), 20000)
assert ann.loss_trend == []
assert ann.model is None

def test_fit_mono(self):
x, y = get_xy()
ann = ANN([10, 5], nn.Tanh(), [20000, 1e-5])
Expand All @@ -33,7 +59,7 @@ def test_fit_01(self):
ann = ANN([10, 5], nn.Tanh(), [20000, 1e-8])
ann.fit(x, y)
assert isinstance(ann.model, nn.Sequential)

def test_fit_02(self):
x, y = get_xy()
ann = ANN([10, 5, 2], [nn.Tanh(), nn.Sigmoid(), nn.Tanh()], [20000, 1e-8])
Expand All @@ -55,11 +81,56 @@ def test_predict_02(self):
ann.fit(x, y)
test_y = ann.predict(x)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)

def test_predict_03(self):
np.random.seed(1)
x, y = get_xy()
ann = ANN([10, 5], nn.Tanh(), 1e-8)
ann.fit(x, y)
test_y = ann.predict(x)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)

def test_convert_numpy_to_torch(self):
arr = [1.0, 2.0, 3.0, 4.0, 5.0]

ann = ANN([10, 5], nn.Tanh(), 20000)

value = ann._convert_numpy_to_torch(np.array(arr))
assert isinstance(value, Tensor)
for i in range(len(arr)):
assert value[i] == arr[i]

def test_convert_torch_to_numpy(self):
arr = [1.0, 2.0, 3.0, 4.0, 5.0]
tensor = torch.from_numpy(np.array(arr)).float()

ann = ANN([10, 5], nn.Tanh(), 20000)

value = ann._convert_torch_to_numpy(tensor)
assert isinstance(value, np.ndarray)
for i in range(len(arr)):
assert value[i] == arr[i]

def test_build_model(self):
passed_func = nn.Tanh()
ann = ANN([10, 5, 2], passed_func, 20000)

ann._build_model(np.array([4]), np.array([12,3]))

assert len(ann.model) == 6 + 1
for i in range(7):
layer = ann.model[i]
# the last layer, I keep the separated for clarity
if i == 6:
assert isinstance(layer, nn.Linear)
elif i % 2 == 0:
assert isinstance(layer, nn.Linear)
else:
assert layer == passed_func

# check input and output
assert ann.model[6].out_features == np.array([12, 3]).shape
assert ann.model[0].in_features == np.array([4]).shape

for i in range(0, 5, 2):
assert ann.model[i].out_features == ann.model[i].in_features

0 comments on commit 90c2eef

Please sign in to comment.