Skip to content

Commit

Permalink
Merge d85f208 into b9e787f
Browse files Browse the repository at this point in the history
  • Loading branch information
ndem0 committed Jun 24, 2020
2 parents b9e787f + d85f208 commit f0dbde9
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 9 deletions.
19 changes: 17 additions & 2 deletions ezyrb/gpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ class GPR(Approximation):
:cvar numpy.ndarray Y_sample: the array containing the output values,
arranged by row.
:cvar GPy.models.GPRegression model: the regression model.
Example:
>>> import ezyrb
>>> import numpy as np
>>> x = np.random.uniform(-1, 1, size=(4, 2))
>>> y = (np.sin(x[:, 0]) + np.cos(x[:, 1]**3)).reshape(-1, 1)
>>> gpr = ezyrb.GPR()
>>> gpr.fit(x, y)
>>> y_pred = gpr.predict(x)
>>> print(np.allclose(y, y_pred))
"""

def __init__(self):
Expand All @@ -30,8 +41,12 @@ def fit(self, points, values, kern=None, optimization_restart=20):
:param array_like points: the coordinates of the points.
:param array_like values: the values in the points.
"""
self.X_sample = np.atleast_2d(points)
self.Y_sample = np.atleast_2d(values)
self.X_sample = np.array(points)
self.Y_sample = np.array(values)
if self.X_sample.ndim == 1:
self.X_sample = self.X_sample.reshape(-1, 1)
if self.Y_sample.ndim == 1:
self.Y_sample = self.Y_sample.reshape(-1, 1)

if kern is None:
kern = GPy.kern.RBF(
Expand Down
17 changes: 15 additions & 2 deletions ezyrb/rbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ class RBF(Approximation):
:cvar kernel: The radial basis function; the default is ‘multiquadric’.
:cvar list interpolators: the RBF interpolators (the number of
interpolators depenend by the dimensionality of the output)
Example:
>>> import ezyrb
>>> import numpy as np
>>>
>>> x = np.random.uniform(-1, 1, size=(4, 2))
>>> y = np.array([np.sin(x[:, 0]), np.cos(x[:, 1]**3)]).T
>>> rbf = ezyrb.RBF()
>>> rbf.fit(x, y)
>>> y_pred = rbf.predict(x)
>>> print(np.allclose(y, y_pred))
"""

def __init__(self, kernel='multiquadric', smooth=0):
Expand All @@ -35,7 +47,7 @@ def fit(self, points, values):
:param array_like values: the values in the points.
"""
self.interpolators = []
for value in values:
for value in values.T:
argument = np.hstack([points, value.reshape(-1, 1)]).T
self.interpolators.append(
Rbf(*argument, smooth=self.smooth, function=self.kernel))
Expand All @@ -48,4 +60,5 @@ def predict(self, new_point):
:return: the interpolated values.
:rtype: numpy.ndarray
"""
return np.array([interp(*new_point) for interp in self.interpolators])
new_point = np.array(new_point)
return np.array([interp(*new_point.T) for interp in self.interpolators]).T
8 changes: 5 additions & 3 deletions ezyrb/reducedordermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ def __init__(self, database, reduction, approximation):
self.reduction = reduction
self.approximation = approximation

def fit(self):
def fit(self, *args, **kwargs):
"""
Calculate reduced space
"""
self.approximation.fit(self.database.parameters,
self.reduction.reduce(self.database.snapshots.T))
self.approximation.fit(
self.database.parameters,
self.reduction.reduce(self.database.snapshots.T).T,
*args, **kwargs)

return self

Expand Down
Binary file modified tests/test_datasets/p_predsol_gpr.npy
Binary file not shown.
8 changes: 6 additions & 2 deletions tests/test_reducedordermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_predict_01(self):
np.testing.assert_allclose(pred_sol, pred_sol_tst, rtol=1e-4, atol=1e-5)

def test_predict_02(self):
np.random.seed(117)
pod = POD(method='svd', rank=4)
gpr = GPR()
db = Database(param, snapshots.T)
Expand All @@ -37,9 +38,12 @@ def test_predict_03(self):
pod = POD(method='svd', rank=3)
gpr = GPR()
db = Database(param, snapshots.T)
#rom = ROM(db, pod, RBF()).fit()
#pred_sol = rom.predict([-.45, -.45])
#print(pred_sol)
rom = ROM(db, pod, gpr).fit()
pred_sol = rom.predict([-.45, -.45])
np.testing.assert_allclose(pred_sol, pred_sol_gpr, rtol=1e-4, atol=1e-5)
pred_sol = rom.predict(db.parameters[2])
assert pred_sol.shape == db.snapshots[0].shape

def test_loo_error(self):
pod = POD()
Expand Down

0 comments on commit f0dbde9

Please sign in to comment.