Skip to content

Commit

Permalink
Merge 3806f3d into f3fe324
Browse files Browse the repository at this point in the history
  • Loading branch information
fandreuz committed Mar 5, 2021
2 parents f3fe324 + 3806f3d commit 1fbf97c
Showing 1 changed file with 74 additions and 2 deletions.
76 changes: 74 additions & 2 deletions tests/test_ann.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch.nn as nn
from torch import Tensor, from_numpy

from unittest import TestCase
from ezyrb import ANN
Expand All @@ -22,6 +23,32 @@ class TestANN(TestCase):
def test_constructor_empty(self):
ann = ANN([10, 5], nn.Tanh(), 20000)

def test_constrctor_loss_none(self):
ann = ANN([10, 5], nn.Tanh(), 20000, loss=None)
assert isinstance(ann.loss, nn.MSELoss)

def test_constructor_single_function(self):
passed_func = nn.Tanh()
ann = ANN([10, 5], passed_func, 20000)

assert isinstance(ann.function, list)
for func in ann.function:
assert func == passed_func

def test_constructor_layers(self):
ann = ANN([10, 5], nn.Tanh(), 20000)
assert ann.layers == [10, 5]

def test_constructor_stop_training(self):
ann = ANN([10, 5], nn.Tanh(), 20000)
assert isinstance(ann.stop_training, list)
assert ann.stop_training == [20000]

def test_constructor_fields_initialized(self):
ann = ANN([10, 5], nn.Tanh(), 20000)
assert ann.loss_trend == []
assert ann.model is None

def test_fit_mono(self):
x, y = get_xy()
ann = ANN([10, 5], nn.Tanh(), [20000, 1e-5])
Expand All @@ -33,7 +60,7 @@ def test_fit_01(self):
ann = ANN([10, 5], nn.Tanh(), [20000, 1e-8])
ann.fit(x, y)
assert isinstance(ann.model, nn.Sequential)

def test_fit_02(self):
x, y = get_xy()
ann = ANN([10, 5, 2], [nn.Tanh(), nn.Sigmoid(), nn.Tanh()], [20000, 1e-8])
Expand All @@ -55,11 +82,56 @@ def test_predict_02(self):
ann.fit(x, y)
test_y = ann.predict(x)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)

def test_predict_03(self):
np.random.seed(1)
x, y = get_xy()
ann = ANN([10, 5], nn.Tanh(), 1e-8)
ann.fit(x, y)
test_y = ann.predict(x)
np.testing.assert_array_almost_equal(y, test_y, decimal=3)

def test_convert_numpy_to_torch(self):
arr = [1.0, 2.0, 3.0, 4.0, 5.0]

ann = ANN([10, 5], nn.Tanh(), 20000)

value = ann._convert_numpy_to_torch(np.array(arr))
assert isinstance(value, Tensor)
for i in range(len(arr)):
assert value[i] == arr[i]

def test_convert_torch_to_numpy(self):
arr = [1.0, 2.0, 3.0, 4.0, 5.0]
tensor = from_numpy(np.array(arr)).float()

ann = ANN([10, 5], nn.Tanh(), 20000)

value = ann._convert_torch_to_numpy(tensor)
assert isinstance(value, np.ndarray)
for i in range(len(arr)):
assert value[i] == arr[i]

def test_build_model(self):
passed_func = nn.Tanh()
ann = ANN([10, 5, 2], passed_func, 20000)

ann._build_model(np.array([[1,2],[3,4]]), np.array([[5,6],[7,8]]))

assert len(ann.model) == 6 + 1
for i in range(7):
layer = ann.model[i]
# the last layer, I keep the separated for clarity
if i == 6:
assert isinstance(layer, nn.Linear)
elif i % 2 == 0:
assert isinstance(layer, nn.Linear)
else:
assert layer == passed_func

# check input and output
assert ann.model[6].out_features == np.array([[5,6],[7,8]]).shape[1]
assert ann.model[0].in_features == np.array([[1,2],[3,4]]).shape[1]

for i in range(0, 5, 2):
assert ann.model[i].out_features == ann.model[i+2].in_features

0 comments on commit 1fbf97c

Please sign in to comment.