Skip to content

Commit

Permalink
Directly subclass sklearn estimators in nonlinearities.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Niru Maheswaranathan committed Nov 1, 2016
1 parent c0e34f7 commit a069fa3
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 47 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ _build/
tags
dist/
*.swp
.coverage
htmlcov/
.cache/
19 changes: 13 additions & 6 deletions pyret/nonlinearities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from scipy.interpolate import interp1d
from functools import wraps
from itertools import zip_longest
from sklearn.base import BaseEstimator, RegressorMixin

try:
import GPy
except ImportError:
except ImportError: # pragma: no cover
print("You must install GPy (pip install GPy) to fit the GP regression nonlinearity.")

__all__ = ['Sigmoid', 'Binterp', 'GaussianProcess']


class Nonlinearity:
class Nonlinearity(BaseEstimator, RegressorMixin):
def plot(self, span=(-5, 5), n=100, **kwargs):
"""Creates a 1D plot of the nonlinearity
Expand Down Expand Up @@ -74,12 +75,17 @@ def __call__(self, x):


class GaussianProcess(Nonlinearity):
def __init__(self, variance=1., lengthscale=1.):
def __init__(self, variance=1., length_scale=None):
"""A nonlinearity fit using Gaussian Process (GP) regression.
Parameters
----------
variance (optional): float
length_scale (optional): float
"""

# Defines the kernel to use
self.kernel = GPy.kern.RBF(input_dim=1, variance=variance, lengthscale=lengthscale)
self.kernel = GPy.kern.RBF(input_dim=1, variance=variance, lengthscale=length_scale)

def fit(self, x, y):
"""Fits the GP regression model."""
Expand All @@ -89,11 +95,12 @@ def fit(self, x, y):

def predict(self, x):
"""Gets the mean prediction at the given values."""
return self.model.predict(x[:, np.newaxis])[0]
return self.predict_full(x)[0]

def predict_full(self, x):
"""Predicts the mean and variance at the given values."""
return self.model.predict(x[:, np.newaxis])
mean, stdev = self.model.predict(x[:, np.newaxis])
return np.squeeze(mean), np.squeeze(stdev)


class Sigmoid(Nonlinearity):
Expand Down
9 changes: 6 additions & 3 deletions tests/test_filtertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from pyret import filtertools as flt
from pyret.stimulustools import slicestim


def test_linear_prediction_one_dim():
"""Test method for computing linear prediction from a
"""Test method for computing linear prediction from a
filter to a one-dimensional stimulus.
"""
filt = np.random.randn(100,)
Expand All @@ -20,8 +21,9 @@ def test_linear_prediction_one_dim():
sl = slicestim(stim, filt.shape[0])
assert np.allclose(filt.reshape(1, -1).dot(sl), pred)


def test_linear_prediction_multi_dim():
"""Test method for computing linear prediction from a
"""Test method for computing linear prediction from a
filter to a multi-dimensional stimulus.
"""
for ndim in range(2, 4):
Expand All @@ -37,12 +39,14 @@ def test_linear_prediction_multi_dim():

assert np.allclose(tmp, pred)


def test_linear_prediction_raises():
"""Test raising ValueErrors with incorrect inputs"""
with pytest.raises(ValueError):
flt.linear_prediction(np.random.randn(10,), np.random.randn(10,2))
flt.linear_prediction(np.random.randn(10, 2), np.random.randn(10, 3))


def test_revco():
"""Test computation of a linear filter by reverse correlation"""
# Create fake filter
Expand All @@ -60,4 +64,3 @@ def test_revco():
filt = flt.revco(response, stimulus, filter_length, norm=True)
tol = 0.1
assert np.allclose(true, filt, atol=tol)

62 changes: 24 additions & 38 deletions tests/test_nonlinearities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,39 @@
(C) 2016 The Baccus Lab
"""
import numpy as np
import pytest
import pyret.nonlinearities as nln

from pyret import nonlinearities

def test_sigmoid():
"""Test the Sigmoid nonlinearity class"""
# True parameters
thresh = 0.5
slope = 2
peak = 1.5
baseline = 0.2
n = 1000 # Number of simulate data points
xscale = 2 # Scale factor for input range
noise = 0.1 # Standard deviation of AWGN
nonlinearities = [
(nln.Sigmoid, (), 1234, 0.1),
(nln.Binterp, (50,), 1234, 0.1),
(nln.GaussianProcess, (), 1234, 0.1),
(nln.Sigmoid, (), 5678, 0.5),
(nln.Binterp, (25,), 5678, 0.5),
(nln.GaussianProcess, (), 5678, 0.5),
]

# Simulate data
x = np.random.randn(n,) * xscale
y = nonlinearities.Sigmoid._sigmoid(x, thresh, slope,
peak, baseline) + np.random.randn(n,) * noise

# Fit nonlinearity and compare
y_hat = nonlinearities.Sigmoid().fit(x, y).predict(x)
norm = (np.linalg.norm(y - y_hat) / np.linalg.norm(y))
if (norm > (noise * 1.5)):
raise AssertionError("Fitting a Sigmoid nonlinearity seems " +
"to have failed, relative error = {0:#0.3f}".format(norm))
@pytest.mark.parametrize("nln_cls,args,seed,noise_stdev", nonlinearities)
def test_fitting(nln_cls, args, seed, noise_stdev):
"""Test the fit method of each nonlinearity"""
np.random.seed(seed)

def test_binterp():
"""Test the Binterp nonlinearity class"""
# True parameters
# simulate a noisy nonlinearity
thresh = 0.5
slope = 2
peak = 1.5
baseline = 0.2
n = 1000 # Number of simulate data points
xscale = 2 # Scale factor for input range
noise = 0.1 # Standard deviation of AWGN
nbins = 25 # Number of bins in the Binterp nonlienarity

# Simulate data
n = 1000 # Number of simulate data points
xscale = 2 # Scale factor for input range
x = np.random.randn(n,) * xscale
y = nonlinearities.Sigmoid._sigmoid(x, thresh, slope,
peak, baseline) + np.random.randn(n,) * noise
y = nln.Sigmoid._sigmoid(x, thresh, slope, peak, baseline)
y_obs = y + np.random.randn(n,) * noise_stdev

# Fit nonlinearity and compare
y_hat = nonlinearities.Binterp(nbins).fit(x, y).predict(x)
norm = (np.linalg.norm(y - y_hat) / np.linalg.norm(y))
if (norm > (noise * 1.5)):
raise AssertionError("Fitting a Sigmoid nonlinearity seems " +
"to have failed, relative error = {0:#0.3f}".format(norm))
# fit nonlinearity to the observed (noisy) data
y_hat = nln_cls(*args).fit(x, y_obs).predict(x)

# compute relative error
rel_error = np.linalg.norm(y - y_hat) / np.linalg.norm(y)
assert rel_error < (0.5 * noise_stdev)

0 comments on commit a069fa3

Please sign in to comment.