diff --git a/ezyrb/gpr.py b/ezyrb/gpr.py index 5e515584..5837afa0 100644 --- a/ezyrb/gpr.py +++ b/ezyrb/gpr.py @@ -34,12 +34,17 @@ def __init__(self): self.Y_sample = None self.model = None - def fit(self, points, values, kern=None, optimization_restart=20): + def fit(self, points, values, kern=None, normalizer=True, optimization_restart=20): """ Construct the regression given `points` and `values`. :param array_like points: the coordinates of the points. :param array_like values: the values in the points. + :param GPy.kern kern: kernel object from GPy. + :param bool normalizer: whether to normilize `values` or not. + Defaults to True. + :param int optimization_restart: number of restarts for the + optimization. Defaults to 20. """ self.X_sample = np.array(points) self.Y_sample = np.array(values) @@ -57,7 +62,7 @@ def fit(self, points, values, kern=None, optimization_restart=20): self.X_sample, self.Y_sample, kern, - normalizer=True) + normalizer=normalizer) self.model.optimize_restarts(optimization_restart, verbose=False)