-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
135 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
""" | ||
Module for Artificial Neural Network (ANN) Prediction. | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
import numpy as np | ||
from .approximation import Approximation | ||
|
||
class ANN(Approximation): | ||
""" | ||
Feed-Forward Artifical Neural Network (ANN). | ||
: param int trained_epoch: number of already trained iterations. | ||
: param criterion: Loss definition (Mean Squared). | ||
: type criterion: torch.nn.modules.loss.MSELoss. | ||
Example: | ||
>>> import ezyrb | ||
>>> import numpy as np | ||
>>> x = np.random.uniform(-1, 1, size =(4, 2)) | ||
>>> y = np.array([np.sin(x[:, 0]), np.cos(x[:, 1]**3)]).T | ||
>>> ann = ezyrb.ANN() | ||
>>> ann.fit(x, y) | ||
>>> y_pred = ann.predict(x) | ||
>>> print(y) | ||
>>> print(y_pred) | ||
""" | ||
|
||
def __init__(self): | ||
self.trained_epoch = 0 | ||
self.criterion = torch.nn.MSELoss() | ||
|
||
|
||
def fit(self, points, values): | ||
""" | ||
Build the ANN given 'points' and 'values' and perform training. | ||
Given the number of neurons per layer, a feed-forward NN is defined. | ||
By default: | ||
- niter, number of training iterations: 20000; | ||
- activation function in each inner layer: Tanh; activation function | ||
at the output layer: Identity; | ||
- optimizer: Adam's method with default parameters | ||
(see, e.g., https://pytorch.org/docs/stable/optim.html); | ||
- loss: Mean Squared Loss. | ||
:param numpy.ndarray points: the coordinates of the given (training) points. | ||
:param numpy.ndarray values: the (training) values in the points. | ||
:return the training loss value at termination (after niter iterations). | ||
:rtype: float. | ||
""" | ||
layers = [points.shape[1], 10, 5, values.shape[1]] | ||
niter = 20000 | ||
arguments = [] | ||
for i in range(len(layers)-2): | ||
arguments.append(nn.Linear(layers[i], layers[i+1])) | ||
arguments.append(nn.Tanh()) | ||
arguments.append(nn.Linear(layers[len(layers)-2], layers[len(layers)-1])) | ||
arguments.append(nn.Identity()) | ||
self.model = nn.Sequential(*arguments) | ||
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) | ||
points = torch.from_numpy(points).float() | ||
values = torch.from_numpy(values).float() | ||
for epoch in range(niter): | ||
y_pred = self.model(points) | ||
loss = self.criterion(y_pred, values) | ||
self.optimizer.zero_grad() | ||
loss.backward() | ||
self.optimizer.step() | ||
self.trained_epoch += niter | ||
return loss.item() | ||
|
||
def predict(self, new_point): | ||
""" | ||
Evaluate the ANN at given 'new_points'. | ||
:param array_like new_points: the coordinates of the given points. | ||
:return: the predicted values via the ANN. | ||
:rtype: numpy.ndarray | ||
""" | ||
new_point = np.array(new_point) | ||
new_point = torch.from_numpy(new_point).float() | ||
y_new = self.model(new_point) | ||
return y_new.detach().numpy() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import numpy as np | ||
import torch.nn as nn | ||
|
||
from unittest import TestCase | ||
from ezyrb import ANN | ||
|
||
np.random.seed(17) | ||
|
||
def get_xy(): | ||
npts = 20 | ||
dinput = 4 | ||
|
||
inp = np.random.uniform(-1, 1, size=(npts, dinput)) | ||
out = np.array([ | ||
np.sin(inp[:, 0]) + np.sin(inp[:, 1]**2), | ||
np.cos(inp[:, 2]) + np.cos(inp[:, 3]**2) | ||
]).T | ||
|
||
return inp, out | ||
|
||
class TestANN(TestCase): | ||
def test_constructor_empty(self): | ||
ann = ANN() | ||
|
||
def test_fit_mono(self): | ||
x, y = get_xy() | ||
ann = ANN() | ||
ann.fit(x[:, 0].reshape(len(x),1), y[:, 0].reshape(len(y),1)) | ||
assert isinstance(ann.model, nn.Sequential) | ||
|
||
def test_fit(self): | ||
x, y = get_xy() | ||
ann = ANN() | ||
ann.fit(x, y) | ||
assert isinstance(ann.model, nn.Sequential) | ||
|
||
def test_predict_01(self): | ||
x, y = get_xy() | ||
ann = ANN() | ||
ann.fit(x, y) | ||
test_y = ann.predict(x) | ||
np.testing.assert_array_almost_equal(y, test_y, decimal=3) | ||
|
||
def test_predict_02(self): | ||
np.random.seed(1) | ||
x, y = get_xy() | ||
ann = ANN() | ||
ann.fit(x, y) | ||
test_y = ann.predict(x) | ||
np.testing.assert_array_almost_equal(y, test_y, decimal=3) |