Skip to content

Commit

Permalink
Merge pull request #133 from ggurioli/ANNbranch
Browse files Browse the repository at this point in the history
Adding a line in ann.py
  • Loading branch information
mtezzele committed Mar 5, 2021
2 parents 43fd50d + 9e0668e commit aaa2a61
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ezyrb/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ANN(Approximation):
Example:
>>> import ezyrb
>>> import numpy as np
>>> import torch.nn as nn
>>> x = np.random.uniform(-1, 1, size =(4, 2))
>>> y = np.array([np.sin(x[:, 0]), np.cos(x[:, 1]**3)]).T
>>> ann = ezyrb.ANN([10, 5], nn.Tanh(), [20000,1e-5])
Expand Down Expand Up @@ -113,7 +114,7 @@ def fit(self, points, values):
:param numpy.ndarray values: the (training) values in the points.
"""

self._build_model(points,values)
self._build_model(points, values)
self.optimizer = torch.optim.Adam(self.model.parameters())

points = self._convert_numpy_to_torch(points)
Expand Down

0 comments on commit aaa2a61

Please sign in to comment.