Skip to content

Commit

Permalink
Adds a gaussian process (GP) regression nonlinearity. (#71)
Browse files Browse the repository at this point in the history
* Adds a gaussian process (GP) regression nonlinearity.

* Adds a gaussian process (GP) regression nonlinearity.
  • Loading branch information
nirum committed Sep 10, 2016
1 parent c369401 commit b0828ca
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion pyret/nonlinearities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from functools import wraps
from itertools import zip_longest

__all__ = ['Sigmoid', 'Binterp']
try:
import GPy
except ImportError:
print("You must install GPy (pip install GPy) to fit the GP regression nonlinearity.")

__all__ = ['Sigmoid', 'Binterp', 'GaussianProcess']


class Nonlinearity:
Expand Down Expand Up @@ -50,6 +55,29 @@ def __call__(self, x):
return self.predict(x)


class GaussianProcess(Nonlinearity):
def __init__(self, variance=1., lengthscale=1.):
"""A nonlinearity fit using Gaussian Process (GP) regression.
"""

# Defines the kernel to use
self.kernel = GPy.kern.RBF(input_dim=1, variance=variance, lengthscale=lengthscale)

def fit(self, x, y):
"""Fits the GP regression model."""
self.model = GPy.models.GPRegression(x[:, np.newaxis], y[:, np.newaxis], self.kernel)
self.model.optimize()
return self

def predict(self, x):
"""Gets the mean prediction at the given values."""
return self.model.predict(x[:, np.newaxis])[0]

def predict_full(self, x):
"""Predicts the mean and variance at the given values."""
return self.model.predict(x[:, np.newaxis])


class Sigmoid(Nonlinearity):
def __init__(self, baseline=0., peak=1., slope=1., threshold=0.):
"""A sigmoidal nonlinearity
Expand Down

0 comments on commit b0828ca

Please sign in to comment.