Skip to content

Commit

Permalink
Updates to nonlinearity module. Removed all old unused functions,
Browse files Browse the repository at this point in the history
changed to using a class-based system. Users create nonlinearities of
the given type, then can call fit() and predict() to learn from data and
fit the nonlinearity to new values, respectively.
  • Loading branch information
bnaecker committed Apr 9, 2016
1 parent 744d985 commit 611ee41
Showing 1 changed file with 133 additions and 179 deletions.
312 changes: 133 additions & 179 deletions pyret/nonlinearities.py
Original file line number Diff line number Diff line change
@@ -1,194 +1,148 @@
"""
Tools for fitting nonlinear functions to data
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.interpolate import interp1d
from functools import wraps
from itertools import zip_longest

__all__ = ['gaussian', 'sigmoid', 'dprime', 'fitgaussian',
'fitsigmoid', 'estdprime', 'estnln']


def gaussian(x, mu, sigma):
"""
A 1D (unnormalized) gaussian function
"""

return np.exp(-0.5 * ((x-mu) / sigma)**2)


def sigmoid(x, threshold, slope, peak, offset):
"""
A sigmoidal nonlinearity
"""

return offset + peak / (1 + np.exp(-slope*(x - threshold)))


def dprime(p0, p1):
"""
compute d' between two distributions given mean / standard deviation
Parameters
----------
p0 : (float, float)
Mean and standard deviation for the first distribution
p1 : (float, float)
Mean and standard deviation for the second distribution
"""
return (p1[0] - p0[0]) / np.sqrt(p1[1]**2 + p0[1]**2)


def fitgaussian(xpts, ypts, p0=None):
"""
Fit a gaussian function to noisy data
Parameters
----------
xpts : array_like
x-values of the data to fit
ypts : array_like
y-values of the data to fit
Returns
-------
popt : array_like
The best-fit sigmoidal parameters (threshold, slope, peak, and offset)
yhat : array_like
The estimated y-values at the given locations in xpts
pcov [matrix]:
"""

# estimate initial conditions
if p0 is None:
p0 = (np.mean(xpts), 5*np.mean(np.diff(xpts)))

# normalize the max to have value 1
scalefactor = float(np.max(ypts))
ypts = ypts / scalefactor

# get parameters
popt, pcov = curve_fit(gaussian, xpts, ypts, p0)

# evaluate fit
yhat = gaussian(xpts, *popt) * scalefactor

return popt, yhat, pcov


def fitsigmoid(xpts, ypts, **kwargs):
"""
Fit a sigmoidal function to noisy data
Parameters
----------
xpts : array_like
x-values of the data to fit
ypts : array_like
y-values of the data to fit
kwargs
Optional keyword arguments are passed to `scipy.optimize.curve_fit`, and
can be used to control the fitting procedure more carefully. This may be
needed, e.g., if the nonlinearities are quite noisy.
Returns
-------
popt : array_like
The best-fit sigmoidal parameters (threshold, slope, peak, and offset)
yhat : array_like
The estimated y-values at the given locations in xpts
pcov : array_like
"""

# estimate initial conditions
p0 = (np.mean(xpts), 1, np.max(ypts), np.min(ypts))

# get parameters
popt, pcov = curve_fit(sigmoid, xpts, ypts, p0, **kwargs)

# evaluate fit
yhat = sigmoid(xpts, *popt)

return popt, yhat, pcov


def estdprime(u, r, numbins=100):
"""
Fit a nonlinearity given a 1D stimulus projection u and spiking response r
"""

# pick a set of bins, store centered bins
bins = np.linspace(np.min(u), np.max(u), numbins)
bincenters = bins[:-1] + np.mean(np.diff(bins))*0.5

# bin the raw stimulus distribution
raw, _ = np.histogram(u, bins)

# bin the spike-triggered distribution
data = u[r > 0]
spk, _ = np.histogram(data, bins)

# estimate gaussian parameters
try:
raw_params = fitgaussian(bincenters, raw, (np.mean(u), np.std(u)))[0]
spk_params = fitgaussian(bincenters, spk, (np.mean(data), np.std(data)))[0]
except RuntimeError:
print('Warning: Gaussian curve fit did not converge')
raw_params = (np.mean(u), np.std(u))
spk_params = (np.mean(data), np.std(data))

# estimate d'
return dprime(raw_params, spk_params)


def estnln(u, r, numbins=50):
"""
Fit a nonlinearity given a 1D stimulus projection u and spiking response r
"""
__all__ = ['Sigmoid', 'Binterp']

# the minimum number of data points / bin to keep for fitting
mincount = 2

# bin the raw stimulus distribution
raw, bins = np.histogram(u, numbins)
class Nonlinearity:
def __init__(self):
pass

# bin the spike-triggered distribution
spk, _ = np.histogram(u[r > 0], bins)
def plot(self, start, stop, n=100):
x = np.linspace(start, stop, n)
plt.plot(x, self.predict(x))

# find locations where there are enough data points
locs = np.logical_and((raw > mincount), (spk > mincount))
def fit(self, x, y):
"""Fits the parameters of the nonlinearity
# normalize the two distributions
raw = raw / float(np.sum(raw))
spk = spk / float(np.sum(spk))
Parameters
----------
x : array_like
input to the nonlinearity
# take the ratio of the two distributions
ratio = spk[locs] / raw[locs]
xvals = bins[locs]
y : array_like
output of the nonlinearity (must have the same shape as x)
"""
raise NotImplementedError

def predict(self, x):
"""Computes the value of the function at the given input
# fit a sigmoid to the results
popt, yhat, pcov = fitsigmoid(xvals, ratio)
Parameters
----------
x : array_like
The input to the nonlinearity
Returns
y : array_like
The output of the nonlinearity
"""
raise NotImplementedError

@wraps(predict)
def __call__(self, x):
return self.predict(x)


class Sigmoid(Nonlinearity):
def __init__(self, baseline=0., peak=1., slope=1., threshold=0.):
"""A sigmoidal nonlinearity
Estimates a nonlinearity of the following form:
.. math:: f(x) = \beta + \frac{\alpha}{(1 + \exp(-\gamma * (x - \theta)))}
Usage
-----
>>> f = Sigmoid().fit(x_train, y_train)
>>> yhat = f.predict(x_test) # f(x_test) works as well
Parameters
----------
baseline : float
y-offset (baseline)
peak : float
maximum response
slope : float
gain of the sigmoid
threshold : float
midpoint of the sigmoid
"""
self.init_params = (baseline, peak, slope, threshold)

def fit(self, x, y, **kwargs):
self.params, self.pcov = curve_fit(self._sigmoid, x, y, self.init_params, **kwargs)
return self

@staticmethod
def _sigmoid(x, threshold, slope, peak, baseline):
return baseline + peak / (1 + np.exp(-slope * (x - threshold)))

def predict(self, x):
try:
return self._sigmoid(x, *self.params)
except NameError:
raise RuntimeError('No estimated parameters, call fit() first')


class Binterp(Nonlinearity):
def __init__(self, nbins, method='linear', fill_value='extrapolate'):
"""Interpolated nonlinearity by sorting and binning the data
Given samples (x, y) from the nonlinearity, bin the values using
variable-sized bins with roughly equal counts, and then interpolates
between the mean y-value in each bin using scipy.interpolate.interp1d.
Parameters
----------
nbins : int
How many bins to use along the input axis
# return values
return popt, xvals, ratio, yhat, pcov
method : str, optional
How to do the interpolation (Default: 'linear'). Possible values: 'linear',
'quadratic', 'cubic', 'nearest', 'slinear', 'zero'. See scipy.interpolate.interp1d
for details.
fill_value : str or value, optional
How to fill in values outside the range of bins (Default: 'extrapolate')
Note: 'extrapolate' only works for the 'linear' or 'nearest' methods,
see scipy.interpolate.interp1d for details
"""
self.nbins = nbins
self.method = method
self.fill_value = fill_value

@staticmethod
def _grouper(iterable, n, fillvalue=None):
"Collect data into fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
args = [iter(iterable)] * n

# TODO: make this more performant
return np.array(list(zip_longest(*args, fillvalue=fillvalue)))

def fit(self, x, y):
binsize = int(x.size / self.nbins)

# sort the x-values and create variable bin edges with equal counts
indices = np.argsort(x)
self.bins = x[indices][::binsize]
y_grouped = self._grouper(y[indices], binsize, fillvalue=np.nan)
self.values = np.nanmean(y_grouped, axis=1)

# set the predict function using scipy.interpolate.interp1d
self.predict = interp1d(self.bins, self.values, kind=self.method, fill_value=self.fill_value)
return self

def predict(self, x):
raise RuntimeError('No estimated parameters, call fit() first')

0 comments on commit 611ee41

Please sign in to comment.