diff --git a/ezyrb/gpr.py b/ezyrb/gpr.py index bdd62036..5e515584 100644 --- a/ezyrb/gpr.py +++ b/ezyrb/gpr.py @@ -16,6 +16,17 @@ class GPR(Approximation): :cvar numpy.ndarray Y_sample: the array containing the output values, arranged by row. :cvar GPy.models.GPRegression model: the regression model. + + Example: + >>> import ezyrb + >>> import numpy as np + >>> x = np.random.uniform(-1, 1, size=(4, 2)) + >>> y = (np.sin(x[:, 0]) + np.cos(x[:, 1]**3)).reshape(-1, 1) + >>> gpr = ezyrb.GPR() + >>> gpr.fit(x, y) + >>> y_pred = gpr.predict(x) + >>> print(np.allclose(y, y_pred)) + """ def __init__(self): @@ -30,8 +41,12 @@ def fit(self, points, values, kern=None, optimization_restart=20): :param array_like points: the coordinates of the points. :param array_like values: the values in the points. """ - self.X_sample = np.atleast_2d(points) - self.Y_sample = np.atleast_2d(values) + self.X_sample = np.array(points) + self.Y_sample = np.array(values) + if self.X_sample.ndim == 1: + self.X_sample = self.X_sample.reshape(-1, 1) + if self.Y_sample.ndim == 1: + self.Y_sample = self.Y_sample.reshape(-1, 1) if kern is None: kern = GPy.kern.RBF( diff --git a/ezyrb/rbf.py b/ezyrb/rbf.py index bc54077e..cf0371d5 100644 --- a/ezyrb/rbf.py +++ b/ezyrb/rbf.py @@ -21,6 +21,18 @@ class RBF(Approximation): :cvar kernel: The radial basis function; the default is ‘multiquadric’. :cvar list interpolators: the RBF interpolators (the number of interpolators depenend by the dimensionality of the output) + + Example: + >>> import ezyrb + >>> import numpy as np + >>> + >>> x = np.random.uniform(-1, 1, size=(4, 2)) + >>> y = np.array([np.sin(x[:, 0]), np.cos(x[:, 1]**3)]).T + >>> rbf = ezyrb.RBF() + >>> rbf.fit(x, y) + >>> y_pred = rbf.predict(x) + >>> print(np.allclose(y, y_pred)) + """ def __init__(self, kernel='multiquadric', smooth=0): @@ -35,7 +47,7 @@ def fit(self, points, values): :param array_like values: the values in the points. """ self.interpolators = [] - for value in values: + for value in values.T: argument = np.hstack([points, value.reshape(-1, 1)]).T self.interpolators.append( Rbf(*argument, smooth=self.smooth, function=self.kernel)) @@ -48,4 +60,5 @@ def predict(self, new_point): :return: the interpolated values. :rtype: numpy.ndarray """ - return np.array([interp(*new_point) for interp in self.interpolators]) + new_point = np.array(new_point) + return np.array([interp(*new_point.T) for interp in self.interpolators]).T diff --git a/ezyrb/reducedordermodel.py b/ezyrb/reducedordermodel.py index 19217032..c09e0b90 100644 --- a/ezyrb/reducedordermodel.py +++ b/ezyrb/reducedordermodel.py @@ -12,12 +12,14 @@ def __init__(self, database, reduction, approximation): self.reduction = reduction self.approximation = approximation - def fit(self): + def fit(self, *args, **kwargs): """ Calculate reduced space """ - self.approximation.fit(self.database.parameters, - self.reduction.reduce(self.database.snapshots.T)) + self.approximation.fit( + self.database.parameters, + self.reduction.reduce(self.database.snapshots.T).T, + *args, **kwargs) return self diff --git a/tests/test_datasets/p_predsol_gpr.npy b/tests/test_datasets/p_predsol_gpr.npy index 9c6bfd65..72fff045 100644 Binary files a/tests/test_datasets/p_predsol_gpr.npy and b/tests/test_datasets/p_predsol_gpr.npy differ diff --git a/tests/test_reducedordermodel.py b/tests/test_reducedordermodel.py index 9130272f..f20ed730 100644 --- a/tests/test_reducedordermodel.py +++ b/tests/test_reducedordermodel.py @@ -26,6 +26,7 @@ def test_predict_01(self): np.testing.assert_allclose(pred_sol, pred_sol_tst, rtol=1e-4, atol=1e-5) def test_predict_02(self): + np.random.seed(117) pod = POD(method='svd', rank=4) gpr = GPR() db = Database(param, snapshots.T) @@ -37,9 +38,12 @@ def test_predict_03(self): pod = POD(method='svd', rank=3) gpr = GPR() db = Database(param, snapshots.T) + #rom = ROM(db, pod, RBF()).fit() + #pred_sol = rom.predict([-.45, -.45]) + #print(pred_sol) rom = ROM(db, pod, gpr).fit() - pred_sol = rom.predict([-.45, -.45]) - np.testing.assert_allclose(pred_sol, pred_sol_gpr, rtol=1e-4, atol=1e-5) + pred_sol = rom.predict(db.parameters[2]) + assert pred_sol.shape == db.snapshots[0].shape def test_loo_error(self): pod = POD()